# -------------------------------------------------------------------------------------
# tf_train.py
# TensorFlow application: trains a signal-vs-background neural network on ROOT trees
# -------------------------------------------------------------------------------------

import ROOT
import numpy as np
import tensorflow as tf

# 1. Configuration
# Branch names read from every input tree. Their order fixes the column order
# of the feature matrix and the width of the network's input layer.
variables = [
    "p_K1",
    "p_K2",
    "d0_K1",
    "d0_K2",
    "dedx_K1",
    "dedx_K2",
    "p_phi",
]

def get_data(file_path, tree_name):
    """Read the configured branches of one tree into a single feature array.

    Args:
        file_path: Path of the ROOT file to open.
        tree_name: Name of the tree inside that file.

    Returns:
        The branches named in the module-level ``variables`` list, stacked
        side by side as columns in the order they appear in that list.
    """
    frame = ROOT.RDataFrame(tree_name, file_path)
    branches = frame.AsNumpy(columns=variables)
    per_feature = [branches[name] for name in variables]
    return np.column_stack(per_feature)

# 2. Load and Label
print("Loading ROOT files...")
x_sig = get_data("signal.root", "Events")
x_bkg = get_data("background.root", "Events")

# Signal rows go on top, background rows below; the target vector mirrors
# that layout with 1.0 for every signal row and 0.0 for every background row.
n_sig, n_bkg = len(x_sig), len(x_bkg)
x_all = np.vstack((x_sig, x_bkg)).astype(np.float32)
y_all = np.concatenate(
    (np.ones(n_sig, dtype=np.float32), np.zeros(n_bkg, dtype=np.float32))
)

# 3. Manual Shuffling
# Signal and background are still two contiguous blocks at this point, so the
# rows must be mixed before the positional train/test split further down.
# A seeded generator makes the permutation -- and therefore the split --
# identical from run to run; set SHUFFLE_SEED to None to draw fresh entropy.
SHUFFLE_SEED = 42
rng = np.random.default_rng(SHUFFLE_SEED)
indices = rng.permutation(len(x_all))
# The same permutation is applied to both arrays so labels stay aligned.
x_all = x_all[indices]
y_all = y_all[indices]

# 4. Manual Feature Scaling (Standardization)
# Standardization is (x - mean) / std per feature column. The split point is
# fixed first so that mean and std are estimated from the training rows only;
# using all rows would let the held-out 20% influence the preprocessing.
split_idx = int(0.8 * len(x_all))
if split_idx == 0:
    # With no training rows the statistics below would be NaN.
    raise ValueError(
        f"Not enough events to build a training set: got {len(x_all)} rows."
    )

mean = np.mean(x_all[:split_idx], axis=0)
std = np.std(x_all[:split_idx], axis=0)
# Avoid division by zero if a variable is constant in the training rows
std[std == 0] = 1.0

# Every row (train and test) is transformed with the training statistics.
x_scaled = (x_all - mean) / std

# 5. Manual Train/Test Split (80% train, 20% test)
x_train, x_test = x_scaled[:split_idx], x_scaled[split_idx:]
y_train, y_test = y_all[:split_idx], y_all[split_idx:]

# 6. Build Model
# Fully connected network: two ReLU hidden layers (16 and 8 units) feeding a
# single sigmoid output unit, with one input per entry in `variables`.
n_features = len(variables)
hidden_wide = tf.keras.layers.Dense(16, activation='relu', input_shape=(n_features,))
hidden_narrow = tf.keras.layers.Dense(8, activation='relu')
output_unit = tf.keras.layers.Dense(1, activation='sigmoid')
model = tf.keras.Sequential([hidden_wide, hidden_narrow, output_unit])

# Binary cross-entropy matches the 0/1 targets; accuracy is tracked as a metric.
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

# 7. Train
# The test arrays are passed as validation data, so their loss and accuracy
# are evaluated at the end of every epoch alongside the training metrics.
print("Starting TensorFlow training loop...")
model.fit(
    x=x_train,
    y=y_train,
    batch_size=512,
    epochs=200,
    validation_data=(x_test, y_test),
)

# 8. Save Model and Scaling Parameters
# The mean and std are written next to the network because any data scored
# later (e.g. the Aleph sample) must be standardized with these exact values.
model.save("aleph_nn.h5")
scaling_params = {"mean": mean, "std": std}
np.savez("scaling_params.npz", **scaling_params)

print("\nModel and scaling parameters saved.")
