# -------------------------------------------------------------------------------------
# tmva_train.py
#
# use the following command to see detailed TMVA plots
# $ root -l -e 'TMVA::TMVAGui("TMVA_Output.root")'
# -------------------------------------------------------------------------------------


import ROOT

def _open_root_file(path, mode="READ"):
    """Open a ROOT file and raise OSError if the handle is unusable.

    A null or zombie handle is rejected here so that a missing or corrupt
    file is reported with its path instead of failing later inside TMVA.
    """
    handle = ROOT.TFile.Open(path, mode)
    if not handle or handle.IsZombie():
        raise OSError(f"Could not open ROOT file '{path}' (mode={mode})")
    return handle


def _get_tree(root_file, tree_name, path):
    """Fetch ``tree_name`` from an open ROOT file, raising if it is absent."""
    tree = root_file.Get(tree_name)
    if not tree:
        raise RuntimeError(f"Tree '{tree_name}' not found in '{path}'")
    return tree


print("[RMVA EXAMPLE] Training with BDT ...")

# 1. Create the output file
out_file = _open_root_file("TMVA_Output.root", "RECREATE")

# Pre-declare the input handles so the cleanup below is safe even when
# opening one of them fails.
input_sig = None
input_bkg = None

try:
    # 2. Initialize Factory and DataLoader
    # The first argument is the base name for the weight files
    factory = ROOT.TMVA.Factory(
        "TMVAClassification", out_file,
        "!V:!Silent:Color:DrawProgressBar:Transformations=I;D;P;G,D:AnalysisType=Classification")

    dataloader = ROOT.TMVA.DataLoader("dataset")

    # 3. Define the input variables
    # Format: dataloader.AddVariable("name", 'type')
    for variable_name in ("p_K1", "p_K2", "d0_K1", "d0_K2",
                          "dedx_K1", "dedx_K2", "p_phi"):
        dataloader.AddVariable(variable_name, 'F')

    # 4. Open the ROOT files and get the Trees
    input_sig = _open_root_file("signal.root")
    input_bkg = _open_root_file("background.root")

    signal_tree = _get_tree(input_sig, "Events", "signal.root")
    background_tree = _get_tree(input_bkg, "Events", "background.root")

    # 5. Add trees to the dataloader (second argument is the per-tree weight)
    dataloader.AddSignalTree(signal_tree, 1.0)
    dataloader.AddBackgroundTree(background_tree, 1.0)

    # 6. Prepare training and test trees (empty string = no selection cut)
    dataloader.PrepareTrainingAndTestTree("", "SplitMode=Random:NormMode=NumEvents:!V")

    # 7. Book the BDT Method
    factory.BookMethod(dataloader, ROOT.TMVA.Types.kBDT, "BDT",
                       "!H:!V:NTrees=850:MinNodeSize=2.5%:MaxDepth=3:BoostType=AdaBoost:"
                       "AdaBoostBeta=0.5:UseBaggedBoost:BaggedSampleFraction=0.5:"
                       "SeparationType=GiniIndex:nCuts=20")

    # 8. Run the Training, Testing, and Evaluation
    factory.TrainAllMethods()
    factory.TestAllMethods()
    factory.EvaluateAllMethods()
finally:
    # Close every file that was opened, even if a step above failed.
    # The input files are closed only here because the trees they hold
    # are still in use until the factory calls above have finished.
    for root_file in (input_sig, input_bkg, out_file):
        if root_file:
            root_file.Close()
