# -------------------------------------------------------------------------------------
# tmva_datasets.py
# -------------------------------------------------------------------------------------

import ROOT
import math
from array import array

print("[TMVA EXAMPLE] Generating datasets ...")

# Random generator used by the event loop to down-sample non-signal pairs.
# TRandom3 (Mersenne Twister) is used instead of the base TRandom class, whose
# linear-congruential generator is documented by ROOT as unsuitable for
# statistical studies. The fixed seed keeps the generated datasets reproducible.
rnd = ROOT.TRandom3(123456)

# Mass assigned to every track when building its four-vector (kaon hypothesis).
kaonMass = 0.493677  # GeV, from PDG

# Signal output: a ROOT file holding the "Events" tree. Every branch is backed
# by a one-element float buffer (array.array with typecode 'f'); the event loop
# stores the value in element [0] and then calls tree_sig.Fill().
file_sig = ROOT.TFile("signal.root", "recreate")
tree_sig = ROOT.TTree("Events", "signal tree")

# Seven independent single-float buffers, one per branch.
sp_K1, sp_K2, sd0_K1, sd0_K2, sdedx_K1, sdedx_K2, sp_phi = (
    array('f', [0.]) for _ in range(7)
)

# (branch name, buffer, leaf list) in the order the branches are booked.
for _sig_branch in (
    ("p_K1", sp_K1, "p_K1/F"),
    ("p_K2", sp_K2, "p_K2/F"),
    ("d0_K1", sd0_K1, "d0_K1/F"),
    ("d0_K2", sd0_K2, "d0_K2/F"),
    ("dedx_K1", sdedx_K1, "dedx_K1/F"),
    ("dedx_K2", sdedx_K2, "dedx_K2/F"),
    ("p_phi", sp_phi, "p_phi/F"),
):
    tree_sig.Branch(*_sig_branch)

# Background output: same layout as the signal tree (tree name "Events", seven
# float branches), written to its own file. Each branch is backed by a
# one-element array.array('f') buffer that the event loop fills before calling
# tree_bgr.Fill().
file_bgr = ROOT.TFile("background.root", "recreate")
tree_bgr = ROOT.TTree("Events", "background tree")

# Seven independent single-float buffers, one per branch.
bp_K1, bp_K2, bd0_K1, bd0_K2, bdedx_K1, bdedx_K2, bp_phi = (
    array('f', [0.]) for _ in range(7)
)

# (branch name, buffer, leaf list) in the order the branches are booked.
for _bgr_branch in (
    ("p_K1", bp_K1, "p_K1/F"),
    ("p_K2", bp_K2, "p_K2/F"),
    ("d0_K1", bd0_K1, "d0_K1/F"),
    ("d0_K2", bd0_K2, "d0_K2/F"),
    ("dedx_K1", bdedx_K1, "dedx_K1/F"),
    ("dedx_K2", bdedx_K2, "dedx_K2/F"),
    ("p_phi", bp_phi, "p_phi/F"),
):
    tree_bgr.Branch(*_bgr_branch)

# --- Chain setup ---
# Three chains over the same set of input files, one per tree name. They are
# read in lockstep (same entry number) by the event loop.
_mc_file_pattern = "/hepdata/Aleph/AlephMC9*.root"
Event, Track, Truth = (
    ROOT.TChain(tree_name) for tree_name in ("EventInfo", "Track", "Truth")
)
for _chain in (Event, Track, Truth):
    _chain.Add(_mc_file_pattern)

# --- Event loop ---
# For every good event: build kaon-hypothesis four-vectors for the tracks,
# combine positive with negative tracks into phi -> K+ K- candidates, and fill
# truth-matched pairs into the signal tree and a random subset of all other
# pairs into the background tree.

# Selection / sampling parameters (previously inline literals).
fraction = 0.1           # fraction of the chained entries to process
MIN_TRACK_P = 1.0        # minimum track momentum (same units as trk_px/py/pz)
MAX_PAIR_MASS = 1.2      # upper cut on the pair invariant mass
BKG_KEEP_PROB = 0.007    # probability of keeping a non-signal pair
NO_TRUTH_INDEX = 999     # trk_truthindex >= this means "no truth particle"

# Truth codes required for a signal pair. NOTE(review): presumably the K+, K-
# and phi codes of the ntuple's particle numbering scheme -- confirm against
# the ntuple documentation.
POS_TRUTH_ID = 11
NEG_TRUTH_ID = 12
MOTHER_TRUTH_ID = 90
MOTHER_N_DAUGHTERS = 2

nentries = int(fraction * Track.GetEntries())
print(f"Nevents: {nentries}")

for ev in range(nentries):
    # Report progress before any selection so the counter still advances when
    # the event at a multiple of 1000 is rejected.
    if ev % 1000 == 0:
        pct = math.ceil(100.0 * ev / nentries)
        print(f"{ev} / {nentries} [{pct}%]", end="\r", flush=True)

    # Select good hadronic events (e+ e- -> Z -> q qbar). Only the event-level
    # chain is read first; track and truth data are loaded for accepted events.
    Event.GetEntry(ev)
    if not list(Event.evt_goodevent)[0]:
        continue
    Track.GetEntry(ev)
    Truth.GetEntry(ev)

    # Truth branches are indexed at random through trk_truthindex, so they are
    # materialised as lists. `truth_id` avoids shadowing the builtin `id`.
    truth_id = list(Truth.tru_id)
    m1id = list(Truth.tru_m1id)
    m1brc = list(Truth.tru_m1brc)
    m1ndau = list(Truth.tru_m1ndau)

    # Selected tracks, split by charge sign. Each entry is
    # (four-vector, dE/dx pull, d0, truth id, mother id, mother barcode,
    #  mother daughter count).
    positives = []
    negatives = []

    # Per-track branches are walked in lockstep, in track order.
    for x, y, z, charge, impact, has_dedx, de, de_exp, de_sig, truth_idx in zip(
        Track.trk_px, Track.trk_py, Track.trk_pz, Track.trk_chg,
        Track.trk_d0, Track.trk_dedxok, Track.trk_dedx,
        Track.trk_dedxkaexp, Track.trk_dedxkasig, Track.trk_truthindex,
    ):
        # Only tracks with a definite charge sign can form a +/- pair. This
        # also prevents a track from being combined with itself, which the
        # previous "!= -1" / "!= +1" filters allowed for any other charge.
        if charge > 0:
            bucket = positives
        elif charge < 0:
            bucket = negatives
        else:
            continue

        # Four-vector under the kaon mass hypothesis.
        p4 = ROOT.TLorentzVector()
        p4.SetXYZM(x, y, z, kaonMass)
        if p4.P() < MIN_TRACK_P:
            continue

        # dE/dx pull with respect to the kaon expectation; left at 0 when the
        # measurement is flagged unusable or the resolution is zero (which
        # would otherwise raise ZeroDivisionError and abort the job).
        chi_dedx = 0.0
        if has_dedx and de_sig != 0:
            chi_dedx = (de - de_exp) / de_sig

        # Truth information; stays 0 when the track has no truth particle.
        parid = motherid = motherbrc = motherndau = 0
        if truth_idx < NO_TRUTH_INDEX:
            parid = truth_id[truth_idx]
            motherid = m1id[truth_idx]
            motherbrc = m1brc[truth_idx]
            motherndau = m1ndau[truth_idx]

        bucket.append(
            (p4, chi_dedx, impact, parid, motherid, motherbrc, motherndau)
        )

    # Combine + and - pairs (to get phi -> K+ K- candidates). Positive tracks
    # form the outer loop and negative tracks the inner loop, both in track
    # order, so pairs are visited in the same order as before.
    for p_pos, chi_pos, d0_pos, id_pos, mid_pos, brc_pos, ndau_pos in positives:
        # Part of the signal definition that depends on the positive track
        # only; evaluated once per positive track.
        pos_is_signal_leg = (
            id_pos == POS_TRUTH_ID
            and mid_pos == MOTHER_TRUTH_ID
            and ndau_pos == MOTHER_N_DAUGHTERS
        )

        for p_neg, chi_neg, d0_neg, id_neg, _mid_neg, brc_neg, _ndau_neg in negatives:
            phi = p_pos + p_neg  # add 4-vectors
            if phi.M() > MAX_PAIR_MASS:
                continue

            # Signal: both tracks carry the required truth ids and share the
            # same mother barcode, and that mother has the required id and
            # daughter count.
            phiSignal = (
                pos_is_signal_leg
                and id_neg == NEG_TRUTH_ID
                and brc_pos == brc_neg
            )

            if phiSignal:
                sp_K1[0] = p_pos.P()
                sp_K2[0] = p_neg.P()
                sdedx_K1[0] = abs(chi_pos)
                sdedx_K2[0] = abs(chi_neg)
                sd0_K1[0] = abs(d0_pos)
                sd0_K2[0] = abs(d0_neg)
                sp_phi[0] = phi.P()
                tree_sig.Fill()
            elif rnd.Uniform() < BKG_KEEP_PROB:
                # Keep only a small random fraction of the non-signal pairs.
                bp_K1[0] = p_pos.P()
                bp_K2[0] = p_neg.P()
                bdedx_K1[0] = abs(chi_pos)
                bdedx_K2[0] = abs(chi_neg)
                bd0_K1[0] = abs(d0_pos)
                bd0_K2[0] = abs(d0_neg)
                bp_phi[0] = phi.P()
                tree_bgr.Fill()

# End of event loop: write each output file and close it, signal first and
# background second.
for _out_file in (file_sig, file_bgr):
    _out_file.Write()
    _out_file.Close()

print("\nDatasets have been generated.")
