# -------------------------------------------------------------------------------------
# tmva_test.py
# -------------------------------------------------------------------------------------

import ROOT
import math
from array import array

# True: analyse simulated (Monte Carlo) files including the generator-level
# Truth tree; False: analyse real data (no Truth tree is read).
MonteCarlo = True

print("[TMVA EXAMPLE] Testing ...")

# --- Constants ---
kaonMass = 0.493677  # GeV from pdg

# --- Chain setup ---
# All trees are read from the same set of input files; the output file name
# follows the same MC/data choice.
if MonteCarlo:
    inputFiles = "/hepdata/Aleph/AlephMC9*.root"
    outputFile = "finalMC.root"
else:
    inputFiles = "/hepdata/Aleph/AlephDA9*.root"
    outputFile = "finalData.root"

Event = ROOT.TChain("EventInfo")
Track = ROOT.TChain("Track")
if MonteCarlo:
    Truth = ROOT.TChain("Truth")
    Truth.Add(inputFiles)
else:
    Truth = None  # no Truth chain is read for data
Event.Add(inputFiles)
Track.Add(inputFiles)
fileout = ROOT.TFile(outputFile, "recreate")

# --- TMVA Settings ---
# One single-element float buffer per BDT input. The Reader holds a reference
# to each buffer, so the event loop sets the inputs by writing element [0]
# before every EvaluateMVA call.
p_K1      = array('f', [0.])
p_K2      = array('f', [0.])
d0_K1     = array('f', [0.])
d0_K2     = array('f', [0.])
dedx_K1   = array('f', [0.])
dedx_K2   = array('f', [0.])
p_phi     = array('f', [0.])
reader = ROOT.TMVA.Reader("!Color:!Silent")
# Variable names and order must match those used when the BDT was trained.
reader.AddVariable("p_K1",  p_K1)
reader.AddVariable("p_K2",  p_K2)
reader.AddVariable("d0_K1", d0_K1)
reader.AddVariable("d0_K2", d0_K2)
reader.AddVariable("dedx_K1", dedx_K1)
reader.AddVariable("dedx_K2", dedx_K2)
reader.AddVariable("p_phi", p_phi)
# Book the BDT under a tag from its weight *file*. (The BookMVA(Types::EMVA, ...)
# overload expects the XML *content* as a string and registers no tag, so
# EvaluateMVA could not look the method up.)
bdtTag = "BDT"
reader.BookMVA(bdtTag, "dataset/weights/TMVAClassification_BDT.weights.xml")

# --- Histograms ---
# Make the output file the current ROOT directory so the histograms created
# below are attached to it and saved by fileout.Write().
fileout.cd()
massAxis  = (100, 0.98, 1.10)  # number of bins, low edge, high edge [GeV]
massTitle = ";m(K+,K-) [GeV];Entries"
hmKK_std_all = ROOT.TH1F("hmKK_std_all", massTitle, *massAxis)  # all +/- pairs
hmKK_std_sig = ROOT.TH1F("hmKK_std_sig", massTitle, *massAxis)  # truth-matched pairs
hmKK_bdt_all = ROOT.TH1F("hmKK_bdt_all", massTitle, *massAxis)  # pairs passing the BDT cut
hmKK_bdt_sig = ROOT.TH1F("hmKK_bdt_sig", massTitle, *massAxis)  # truth-matched, passing BDT cut

# --- Event loop ---
fraction = 1  # fraction of the Track chain entries to process
nentries = int(fraction * Track.GetEntries())
print(f"Nevents: {nentries}")

minTrackP   = 1.0    # minimum kaon-candidate momentum (same units as kaonMass)
maxPairMass = 1.2    # upper cut on m(K+,K-) for pair candidates [GeV]
bdtCut      = -0.06  # BDT score must exceed this to enter the *_bdt_* histograms

for ev in range(nentries):
    # Progress report before any selection, so it is not skipped when entry
    # `ev` fails the good-event cut.
    if ev % 1000 == 0:
        pct = math.ceil(100.0 * ev / nentries)
        print(f"{ev} / {nentries} [{pct}%]", end="\r", flush=True)

    # Select good hadronic events (e+ e- -> Z -> q qbar). Only the EventInfo
    # entry is needed for this, so Track/Truth are loaded after the cut.
    Event.GetEntry(ev)
    isGoodEvent = list(Event.evt_goodevent)
    if not isGoodEvent[0]:
        continue

    # The chains are read in lock-step: entry `ev` of each is assumed to
    # describe the same event.
    Track.GetEntry(ev)
    if MonteCarlo:
        Truth.GetEntry(ev)

    # Read branches from main data files
    px = list(Track.trk_px)
    py = list(Track.trk_py)
    pz = list(Track.trk_pz)
    ch = list(Track.trk_chg)
    d0 = list(Track.trk_d0)
    dedxok = list(Track.trk_dedxok)
    dedx = list(Track.trk_dedx)
    dedxkaexp = list(Track.trk_dedxkaexp)
    dedxkasig = list(Track.trk_dedxkasig)
    tind = list(Track.trk_truthindex)

    if MonteCarlo:
        truId  = list(Truth.tru_id)
        m1id   = list(Truth.tru_m1id)
        m1brc  = list(Truth.tru_m1brc)
        m1ndau = list(Truth.tru_m1ndau)

    # Kaon candidates, one tuple per track passing the momentum cut:
    # (4-vector, charge, dE/dx pull, d0, truth id, mother id, mother brc, mother ndau)
    trklist = []
    for i in range(len(px)):
        # 4-vector under the kaon mass hypothesis
        p = ROOT.TLorentzVector()
        p.SetXYZM(px[i], py[i], pz[i], kaonMass)
        if p.P() < minTrackP:
            continue

        # dE/dx pull w.r.t. the expected kaon value; stays 0 when the
        # measurement is flagged bad or its resolution is 0 (which would
        # otherwise raise ZeroDivisionError).
        chi_dedx = 0
        if dedxok[i] and dedxkasig[i] != 0:
            chi_dedx = (dedx[i] - dedxkaexp[i]) / dedxkasig[i]

        # Truth info (MC only). tind >= 999 presumably flags a track without a
        # truth match (TODO confirm); negative values are rejected as well,
        # since Python would silently index from the end of the truth lists.
        parid, motherid, motherbrc, motherndau = 0, 0, 0, 0
        ti = tind[i]
        if MonteCarlo and 0 <= ti < 999:
            parid      = truId[ti]
            motherid   = m1id[ti]
            motherbrc  = m1brc[ti]
            motherndau = m1ndau[ti]

        trklist.append((p, ch[i], chi_dedx, d0[i],
                        parid, motherid, motherbrc, motherndau))

    # Combine + and - tracks into phi -> K+ K- candidates. Requiring charges of
    # exactly +1 and -1 also guarantees a track is never paired with itself.
    positives = [trk for trk in trklist if trk[1] == +1]
    negatives = [trk for trk in trklist if trk[1] == -1]
    for t1, _, chi1, d0_1, parid1, motherid1, motherbrc1, motherndau1 in positives:
        for t2, _, chi2, d0_2, parid2, _, motherbrc2, _ in negatives:
            phi = t1 + t2  # add 4-vectors
            mass = phi.M()
            if mass > maxPairMass:
                continue

            # MC truth match: positive track with truth id 11 and negative
            # track with id 12 (presumably the K+ / K- codes — confirm), same
            # mother branch, mother id 90 (presumably the phi) and exactly two
            # daughters. Always False for data, where all truth fields are 0.
            phiSignal = (parid1 == 11 and parid2 == 12
                         and motherbrc1 == motherbrc2
                         and motherid1 == 90 and motherndau1 == 2)

            # Fill the Reader's input buffers for this pair, then evaluate.
            p_K1[0]    = t1.P()
            p_K2[0]    = t2.P()
            dedx_K1[0] = abs(chi1)
            dedx_K2[0] = abs(chi2)
            d0_K1[0]   = abs(d0_1)
            d0_K2[0]   = abs(d0_2)
            p_phi[0]   = phi.P()
            bdt_score = reader.EvaluateMVA(bdtTag)

            hmKK_std_all.Fill(mass)
            if phiSignal:
                hmKK_std_sig.Fill(mass)

            if bdt_score > bdtCut:
                hmKK_bdt_all.Fill(mass)
                if phiSignal:
                    hmKK_bdt_sig.Fill(mass)

# --- End of event loop: save output ---
# Write every object attached to the output file (presumably the hmKK_*
# histograms) and close it. The leading "\n" in the message terminates the
# "\r"-overwritten progress line printed during the loop.
fileout.Write()
fileout.Close()

print("\nEvent loop completed.")
