# -------------------------------------------------------------------------------------
# skeleton_aleph_phi2KK.py
#
# Example PyROOT script to read and analyse ALEPH files,
# here for phi -> K+ K- candidates (works for both Monte Carlo and recorded data)
# -------------------------------------------------------------------------------------

import math
from collections import namedtuple

import ROOT

# Select the input sample: True -> simulated events (with a Truth tree),
# False -> recorded data (no Truth tree is attached).
MonteCarlo = True

# --- Enable multi-threading ---
ROOT.EnableImplicitMT()

# --- Chain setup ---
# One chain per tree stored in the ALEPH ntuple file.
Event, Gamma, Track, Truth = (
    ROOT.TChain(treeName) for treeName in ("EventInfo", "Gamma", "Track", "Truth")
)

if MonteCarlo:
    print("Running Aleph Monte Carlo")
    outputName = "histoMC.root"
    inputName = "/hepdata/Aleph/AlephMC91.root"
    activeChains = (Event, Gamma, Track, Truth)
else:
    print("Running Aleph Real Data")
    outputName = "histoDA.root"
    inputName = "/hepdata/Aleph/AlephDA91.root"
    # Real data carries no generator-level information, so Truth stays empty.
    activeChains = (Event, Gamma, Track)

# Output file is opened before booking histograms so they are attached to it.
fileout = ROOT.TFile(outputName, "recreate")
for activeChain in activeChains:
    activeChain.Add(inputName)

# --- Histograms ---
# K+K- invariant mass around the phi peak: every pair, truth-matched signal
# pairs, and all remaining (background) pairs.
massTitle = ";m(K+,K-) [GeV];Entries"
hmKK_all = ROOT.TH1F("hmKK_all", massTitle, 100, 0.98, 1.10)
hmKK_sig = ROOT.TH1F("hmKK_sig", massTitle, 100, 0.98, 1.10)
hmKK_bgr = ROOT.TH1F("hmKK_bgr", massTitle, 100, 0.98, 1.10)

# dE/dx pull under the kaon hypothesis: all selected tracks, and tracks from
# truth-matched signal pairs.
pullTitle = ";chi(dE/dx);Entries"
hdedx_all = ROOT.TH1F("hdedx_all", pullTitle, 100, -5, 5)
hdedx_kaon = ROOT.TH1F("hdedx_kaon", pullTitle, 100, -5, 5)

# Charged-kaon mass in GeV (PDG value), used as the mass hypothesis for tracks.
kaonMass = 0.493677


# --- Event loop ---
# Reconstruct phi -> K+ K- candidates by pairing opposite-charge tracks (kaon
# mass hypothesis) and fill the mass / dE/dx histograms booked above.
# In Monte Carlo, pairs are split into signal/background using the Truth tree.

# Only a fraction of the sample is processed, to keep the run short.
fraction = 0.05
# The loop reads EventInfo/Track/Truth entry by entry, so size it from EventInfo
# (the trees are assumed to hold one entry per event -- TODO confirm).
nentries = int(fraction * Event.GetEntries())
print(f"Nevents: {nentries}")

# Tracks with transverse momentum below this value [GeV] are ignored.
ptMin = 1.5

# Per-track record used for pairing. The first seven fields keep the positional
# order of the former plain tuple, so index access (trk[0] ... trk[6]) still works.
#   p          : TLorentzVector from the track momentum with the kaon mass
#   charge     : track charge (trk_chg)
#   chi_dedx   : (dE/dx - expected_K) / sigma_K; 0 when no usable measurement
#   parid, motherid, motherbrc, motherndau : truth info (0 for data / unmatched)
#   has_dedx   : True when chi_dedx comes from an actual dE/dx measurement
TrackInfo = namedtuple(
    "TrackInfo",
    ["p", "charge", "chi_dedx", "parid", "motherid", "motherbrc", "motherndau", "has_dedx"],
)

for ev in range(nentries):
    # Progress report every 5000 entries, independent of the event selection.
    if ev % 5000 == 0:
        pct = math.ceil(100.0 * ev / nentries)
        print(f"{ev} / {nentries} [{pct}%]", end="\r", flush=True)

    # Select good hadronic events (e+ e- -> Z -> q qbar)
    Event.GetEntry(ev)
    if not Event.evt_goodevent[0]:
        continue

    # Track/Truth entries are only needed for accepted events.
    Track.GetEntry(ev)
    px = list(Track.trk_px)
    py = list(Track.trk_py)
    pz = list(Track.trk_pz)
    ch = list(Track.trk_chg)
    dedxok = list(Track.trk_dedxok)
    dedx = list(Track.trk_dedx)
    dedxkaexp = list(Track.trk_dedxkaexp)
    dedxkasig = list(Track.trk_dedxkasig)

    if MonteCarlo:
        # Truth matching only makes sense for simulated events.
        Truth.GetEntry(ev)
        tind = list(Track.trk_truthindex)
        tru_id = list(Truth.tru_id)
        tru_m1id = list(Truth.tru_m1id)
        tru_m1brc = list(Truth.tru_m1brc)
        tru_m1ndau = list(Truth.tru_m1ndau)

    # Kaon candidates: tracks passing the pT cut
    tracks = []
    for i in range(len(px)):
        # Cheap pT cut before allocating a four-vector.
        if math.hypot(px[i], py[i]) < ptMin:
            continue
        p = ROOT.TLorentzVector()
        p.SetXYZM(px[i], py[i], pz[i], kaonMass)

        # dE/dx pull under the kaon hypothesis. A non-positive resolution
        # cannot give a pull, so such tracks count as "no measurement"
        # (instead of raising ZeroDivisionError).
        has_dedx = bool(dedxok[i]) and dedxkasig[i] > 0
        chi_dedx = 0.0
        if has_dedx:
            chi_dedx = (dedx[i] - dedxkaexp[i]) / dedxkasig[i]
            # Only real measurements go into the pull histogram; filling the
            # placeholder 0 would create an artificial spike at chi = 0.
            hdedx_all.Fill(chi_dedx)

        # Truth info (stays 0 for data and for unmatched tracks)
        parid = motherid = motherbrc = motherndau = 0
        if MonteCarlo:
            k = tind[i]
            # Indices >= 999 are treated as "no truth match" (sentinel kept from
            # the original cut). Negative indices are rejected explicitly because
            # Python would silently wrap them to the end of the list.
            if 0 <= k < 999:
                parid = tru_id[k]
                motherid = tru_m1id[k]
                motherbrc = tru_m1brc[k]
                motherndau = tru_m1ndau[k]

        tracks.append(
            TrackInfo(p, ch[i], chi_dedx, parid, motherid, motherbrc, motherndau, has_dedx)
        )

    # Combine K+ and K- candidates (phi -> K+ K-). Splitting strictly by charge
    # sign guarantees opposite-sign pairs and that no track is paired with itself.
    positives = [trk for trk in tracks if trk.charge > 0]
    negatives = [trk for trk in tracks if trk.charge < 0]
    for kp in positives:
        for km in negatives:
            mass = (kp.p + km.p).M()
            # Truth-matched signal: positive track with code 11, negative with
            # code 12, both from the same mother entry, whose code is 90 and which
            # has exactly two daughters. (Presumably 11/12 = K+/K- and 90 = phi in
            # the Truth tree's particle-code convention -- TODO confirm.)
            phiSignal = (
                kp.parid == 11
                and km.parid == 12
                and kp.motherbrc == km.motherbrc
                and kp.motherid == 90
                and kp.motherndau == 2
            )
            hmKK_all.Fill(mass)
            if phiSignal:
                hmKK_sig.Fill(mass)  # signal
                for kaon in (kp, km):
                    if kaon.has_dedx:
                        hdedx_kaon.Fill(kaon.chi_dedx)
            else:
                hmKK_bgr.Fill(mass)  # background

# End of event loop
print("\nEvent loop completed.")

# Persist the histograms into the output file (same order as booked), then
# close it; closing flushes the file to disk.
for histogram in (hmKK_all, hmKK_sig, hmKK_bgr, hdedx_all, hdedx_kaon):
    histogram.Write()
fileout.Close()

