# -------------------------------------------------------------------------------------
# tf_train.py
# tensorflow application
# this is very slow! Vectorization is required
# -------------------------------------------------------------------------------------

import ROOT
import tensorflow as tf
import numpy as np
import math
from array import array

# Load the model and the scaling parameters
print("Loading Neural Network model...")
model = tf.keras.models.load_model("aleph_nn.h5")

print("Loading scaling parameters...")
params = np.load("scaling_params.npz")
mean = params['mean']
std  = params['std']

def get_nn_score(input_values):
    """
    Evaluate the neural network on a single candidate.

    Inputs:
        input_values -- list or array with the candidate's input features, in
            the same order as used for training (and for the mean/std stored in
            scaling_params.npz). The event loop in this script passes 7 values:
            [p_K1, p_K2, d0_K1, d0_K2, dedx_K1, dedx_K2, p_phi].
    Returns:
        A single float score between 0 (Background) and 1 (Signal).
    """
    # Convert input to numpy array
    vals = np.asarray(input_values, dtype=np.float32)

    # IMPORTANT: Manual scaling using training parameters.
    # This must be identical to the training preprocessing.
    vals_scaled = ((vals - mean) / std).astype(np.float32)

    # TensorFlow expects a batch dimension: (1, n_features)
    vals_reshaped = vals_scaled.reshape(1, -1)

    # Call the model directly: model.predict() sets up a full inference loop on
    # every call, which dominates the cost when scoring one candidate at a time.
    prediction = np.asarray(model(vals_reshaped, training=False))
    return float(prediction[0][0])
# ---------------------------------------------------------------------------------------


def run_dummy_test(dummy_event=None):
    """
    Manual sanity check of get_nn_score() on one hand-made candidate.

    Not called by the script itself; call it by hand (e.g. from an interactive
    session) to check that the model and scaling files load and give a score.

    Inputs:
        dummy_event -- optional list with the 7 input features
            [p_K1, p_K2, d0_K1, d0_K2, dedx_K1, dedx_K2, p_phi];
            defaults to a fixed example candidate.
    Returns:
        The NN score of the candidate.
    """
    if dummy_event is None:
        dummy_event = [15.2, 12.8, 0.005, 0.008, 0.5, -0.2, 27.5]

    score = get_nn_score(dummy_event)

    print("-" * 30)
    print(f"Input Variables: {dummy_event}")
    print(f"Neural Network Score: {score:.4f}")

    if score > 0.5:
        print("Result: Likely Signal")
    else:
        print("Result: Likely Background")
    return score

# --- Run configuration ---
# True: read the AlephMC9* files (including the Truth tree); False: the AlephDA9* files.
MonteCarlo = False

print("Testing TensorFlow ...")

# --- Constants ---
kaonMass = 0.493677  # kaon mass in GeV (PDG value)

# --- Chain setup ---
# All trees of a sample are read from the same files, so the glob is written once.
if MonteCarlo:
    input_files = "/hepdata/Aleph/AlephMC9*.root"
    output_name = "finalMC.root"
else:
    input_files = "/hepdata/Aleph/AlephDA9*.root"
    output_name = "finalData.root"

Event = ROOT.TChain("EventInfo")
Track = ROOT.TChain("Track")
Event.Add(input_files)
Track.Add(input_files)
if MonteCarlo:
    # Generator-level truth is only read for simulated samples.
    Truth = ROOT.TChain("Truth")
    Truth.Add(input_files)

# Opened before the histograms are booked, so they end up in this file.
fileout = ROOT.TFile(output_name, "recreate")

# --- Per-candidate feature buffers ---
# Seven one-element float32 arrays (typecode 'f'), each initialised to 0.
(p_K1, p_K2,
 d0_K1, d0_K2,
 dedx_K1, dedx_K2,
 p_phi) = (array('f', [0.]) for _ in range(7))

# --- Histograms ---
# Every m(K+K-) spectrum shares one axis title and binning: 100 bins on [0.98, 1.10].
_MKK_TITLE = ";m(K+,K-) [GeV];Entries"
_MKK_BINNING = (100, 0.98, 1.10)

hmKK_std_all = ROOT.TH1F("hmKK_std_all", _MKK_TITLE, *_MKK_BINNING)
hmKK_std_sig = ROOT.TH1F("hmKK_std_sig", _MKK_TITLE, *_MKK_BINNING)
hmKK_bdt_all = ROOT.TH1F("hmKK_bdt_all", _MKK_TITLE, *_MKK_BINNING)
hmKK_bdt_sig = ROOT.TH1F("hmKK_bdt_sig", _MKK_TITLE, *_MKK_BINNING)

# --- Event loop ---
fraction = 0.001
nentries = int(fraction * Track.GetEntries())
print(f"Nevents: {nentries}")

# The NN is evaluated on buffered batches of candidates instead of once per K+K- pair:
# the per-call overhead of TensorFlow, not the arithmetic, made the pair-by-pair loop slow.
NN_CUT = 0.5             # NN score threshold; this cut must be optimized
NN_BATCH_SIZE = 4096     # mini-batch size passed to model.predict
NN_FLUSH_SIZE = 100000   # number of buffered candidates that triggers a flush

nn_features = []  # rows of [p_K1, p_K2, d0_K1, d0_K2, dedx_K1, dedx_K2, p_phi]
nn_masses = []    # m(K+K-) of each buffered candidate
nn_signal = []    # truth-match flag of each buffered candidate (always False on data)


def _flush_nn_candidates():
    """Score all buffered candidates in one batch, fill the NN-selected histograms
    for those above NN_CUT, then empty the buffers."""
    if not nn_features:
        return
    features = np.asarray(nn_features, dtype=np.float32)
    # Same preprocessing as get_nn_score(): standardize with the training mean/std.
    scaled = ((features - mean) / std).astype(np.float32)
    scores = np.asarray(model.predict(scaled, batch_size=NN_BATCH_SIZE, verbose=0))[:, 0]
    for cand_mass, cand_is_signal, score in zip(nn_masses, nn_signal, scores):
        if score > NN_CUT:
            hmKK_bdt_all.Fill(cand_mass)
            if cand_is_signal:
                hmKK_bdt_sig.Fill(cand_mass)
    nn_features.clear()
    nn_masses.clear()
    nn_signal.clear()


for ev in range(nentries):
    # Progress is reported for every 10th entry, including rejected events.
    if ev % 10 == 0:
        pct = math.ceil(100.0 * ev / nentries)
        print(f"{ev} / {nentries} [{pct}%]", end="\r", flush=True)

    Event.GetEntry(ev)
    Track.GetEntry(ev)
    if MonteCarlo:
        Truth.GetEntry(ev)

    # Select good hadronic events (e+ e- -> Z -> q qbar)
    if not list(Event.evt_goodevent)[0]:
        continue

    # Read branches from main data files
    px = list(Track.trk_px)
    py = list(Track.trk_py)
    pz = list(Track.trk_pz)
    ch = list(Track.trk_chg)
    d0 = list(Track.trk_d0)
    dedxok = list(Track.trk_dedxok)
    dedx = list(Track.trk_dedx)
    dedxkaexp = list(Track.trk_dedxkaexp)
    dedxkasig = list(Track.trk_dedxkasig)
    tind = list(Track.trk_truthindex)

    if MonteCarlo:
        tru_id = list(Truth.tru_id)   # renamed from `id`, which shadowed the builtin
        m1id = list(Truth.tru_m1id)
        m1brc = list(Truth.tru_m1brc)
        m1ndau = list(Truth.tru_m1ndau)

    # Track list: (4-vector, charge, dE/dx pull, d0, truth id, mother id,
    #              mother m1brc, mother number of daughters)
    trklist = []
    for i, (x, y, z) in enumerate(zip(px, py, pz)):
        # 4-vector under the kaon mass hypothesis
        p = ROOT.TLorentzVector()
        p.SetXYZM(x, y, z, kaonMass)
        if p.P() < 1.0:
            continue

        # dE/dx information based on the kaon hypothesis; 0 when dedxok is false.
        chi_dedx = 0
        if dedxok[i]:
            chi_dedx = (dedx[i] - dedxkaexp[i]) / dedxkasig[i]

        # Truth info (MC only); values >= 999 are presumably "no truth match".
        # NOTE(review): a negative truth index would wrap around in Python — confirm it cannot occur.
        parid, motherid, motherbrc, motherndau = 0, 0, 0, 0
        if MonteCarlo and tind[i] < 999:
            parid = tru_id[tind[i]]
            motherid = m1id[tind[i]]
            motherbrc = m1brc[tind[i]]
            motherndau = m1ndau[tind[i]]

        trklist.append((p, ch[i], chi_dedx, d0[i], parid, motherid, motherbrc, motherndau))

    # Combine + and - pairs (to get phi -> K+ K- candidates). Explicit sign tests keep a
    # zero-charge track out of both lists, so a track is never paired with itself.
    positives = [trk for trk in trklist if trk[1] > 0]
    negatives = [trk for trk in trklist if trk[1] < 0]

    for t1, _, chi1, d0_1, id1, mid1, mbrc1, mndau1 in positives:
        for t2, _, chi2, d0_2, id2, _, mbrc2, _ in negatives:
            phi = t1 + t2  # add 4-vectors
            mass = phi.M()
            if mass > 1.2:
                continue

            # Truth match: codes 11 (positive track) and 12 (negative track), equal
            # mother m1brc, mother code 90 with exactly two daughters (presumably
            # K+, K- and phi in the generator numbering — TODO confirm).
            phiSignal = (id1 == 11 and id2 == 12 and mbrc1 == mbrc2
                         and mid1 == 90 and mndau1 == 2)

            hmKK_std_all.Fill(mass)
            if phiSignal:
                hmKK_std_sig.Fill(mass)

            # [p_K1, p_K2, d0_K1, d0_K2, dedx_K1, dedx_K2, p_phi]
            nn_features.append([t1.P(), t2.P(), abs(d0_1), abs(d0_2),
                                abs(chi1), abs(chi2), phi.P()])
            nn_masses.append(mass)
            nn_signal.append(phiSignal)

    if len(nn_features) >= NN_FLUSH_SIZE:
        _flush_nn_candidates()

# End of event loop: score whatever is still buffered.
_flush_nn_candidates()

fileout.Write()
fileout.Close()

print("\nEvent loop completed.")
