dumpTree.py
Go to the documentation of this file.
1 #! /usr/bin/env python
2 #
3 # Read almost every field in the event tree.
4 #
5 
6 from math import sqrt
7 
8 import numpy as np
9 import fire
10 import h5py
11 
12 from ROOT import TG4Event, TFile
13 
def printPrimaryParticle(depth, primaryParticle):
    """Print the fields of a TG4PrimaryParticle, one per line.

    Args:
        depth: Prefix string written at the start of every output line.
        primaryParticle: The TG4PrimaryParticle to print.

    The momentum line lists X, Y, Z, E, P and M of ``GetMomentum()``.
    """
    print(depth, "Class: ", primaryParticle.ClassName())
    print(depth, "Track Id:", primaryParticle.GetTrackId())
    print(depth, "Name:", primaryParticle.GetName())
    print(depth, "PDG Code:", primaryParticle.GetPDGCode())
    fourMomentum = primaryParticle.GetMomentum()
    components = (fourMomentum.X(), fourMomentum.Y(), fourMomentum.Z(),
                  fourMomentum.E(), fourMomentum.P(), fourMomentum.M())
    print(depth, "Momentum:", *components)
26 
def printPrimaryVertex(depth, primaryVertex):
    """Print the fields of a TG4PrimaryVertex, followed by its children.

    Args:
        depth: Prefix string for the vertex's own lines.
        primaryVertex: The TG4PrimaryVertex to print.

    The vertex's ``Informational`` vertices (printed recursively) and its
    ``Particles`` are printed with the prefix ``depth + ".."``.
    """
    print(depth, "Class: ", primaryVertex.ClassName())
    where = primaryVertex.GetPosition()
    print(depth, "Position:", where.X(), where.Y(), where.Z(), where.T())
    print(depth, "Generator:", primaryVertex.GetGeneratorName())
    print(depth, "Reaction:", primaryVertex.GetReaction())
    print(depth, "Filename:", primaryVertex.GetFilename())
    print(depth, "InteractionNumber:", primaryVertex.GetInteractionNumber())

    childDepth = depth + ".."
    for nestedVertex in primaryVertex.Informational:
        printPrimaryVertex(childDepth, nestedVertex)
    for particle in primaryVertex.Particles:
        printPrimaryParticle(childDepth, particle)
43 
def printTrajectoryPoint(depth, trajectoryPoint):
    """Print the fields of a TG4TrajectoryPoint, one per line.

    Args:
        depth: Prefix string written at the start of every output line.
        trajectoryPoint: The TG4TrajectoryPoint to print.

    The position line lists X, Y, Z, T; the momentum line lists X, Y, Z
    and the magnitude (``Mag()``).
    """
    print(depth, "Class: ", trajectoryPoint.ClassName())
    pos = trajectoryPoint.GetPosition()
    print(depth, "Position:", pos.X(), pos.Y(), pos.Z(), pos.T())
    mom = trajectoryPoint.GetMomentum()
    print(depth, "Momentum:", mom.X(), mom.Y(), mom.Z(), mom.Mag())
    for label, getter in (("Process", trajectoryPoint.GetProcess),
                          ("Subprocess", trajectoryPoint.GetSubprocess)):
        print(depth, label, getter())
57 
def printTrajectory(depth, trajectory):
    """Print a TG4Trajectory and every one of its trajectory points.

    Args:
        depth: Prefix string for the "Class" line.
        trajectory: The TG4Trajectory to print.

    All remaining lines, including each point in ``trajectory.Points``,
    are printed with the prefix ``depth + ".."``.
    """
    print(depth, "Class: ", trajectory.ClassName())
    inner = depth + ".."
    print(inner, "Track Id/Parent Id:",
          trajectory.GetTrackId(), trajectory.GetParentId())
    print(inner, "Name:", trajectory.GetName())
    print(inner, "PDG Code", trajectory.GetPDGCode())
    p0 = trajectory.GetInitialMomentum()
    print(inner, "Initial Momentum:",
          p0.X(), p0.Y(), p0.Z(), p0.E(), p0.P(), p0.M())
    for point in trajectory.Points:
        printTrajectoryPoint(inner, point)
75 
def printHitSegment(depth, hitSegment):
    """Print the fields of a TG4HitSegment, one per line.

    Args:
        depth: Prefix string written at the start of every output line.
        hitSegment: The TG4HitSegment to print.

    Start and stop are printed as X, Y, Z, T; the contributors are
    printed as a Python list built from ``hitSegment.Contrib``.
    """
    print(depth, "Class: ", hitSegment.ClassName())
    print(depth, "Primary Id:", hitSegment.GetPrimaryId())
    print(depth, "Energy Deposit:", hitSegment.GetEnergyDeposit())
    print(depth, "Secondary Deposit:", hitSegment.GetSecondaryDeposit())
    print(depth, "Track Length:", hitSegment.GetTrackLength())
    for label, getter in (("Start:", hitSegment.GetStart),
                          ("Stop:", hitSegment.GetStop)):
        point = getter()
        print(depth, label, point.X(), point.Y(), point.Z(), point.T())
    print(depth, "Contributor:", list(hitSegment.Contrib))
92 
def printSegmentContainer(depth, containerName, hitSegments):
    """Print one entry of an event's SegmentDetectors map.

    Args:
        depth: Prefix string for the "Detector" line.
        containerName: The map key (detector/container name).
        hitSegments: The map value, a container of TG4HitSegment objects
            exposing ``size()``.

    Each hit segment is printed with the prefix ``depth + ".."``.
    """
    print(depth, "Detector: ", containerName, hitSegments.size())
    nested = depth + ".."
    for segment in hitSegments:
        printHitSegment(nested, segment)
100 
def _fillTrajectories(event, eventId, trajectoriesDtype):
    """Return one ``trajectoriesDtype`` row per TG4Trajectory in *event*.

    Start/end quantities come from the first and last trajectory point.
    Positions and momenta are copied as stored in the event, with no unit
    conversion; segment coordinates, by contrast, are divided by 10.
    NOTE(review): confirm downstream consumers expect the two tables in
    different length units.
    """
    trajectories = np.empty(len(event.Trajectories), dtype=trajectoriesDtype)
    for iTraj, trajectory in enumerate(event.Trajectories):
        start_pt, end_pt = trajectory.Points[0], trajectory.Points[-1]
        start_mom, end_mom = start_pt.GetMomentum(), end_pt.GetMomentum()
        start_pos, end_pos = start_pt.GetPosition(), end_pt.GetPosition()
        # A single element of a structured array is a mutable scalar that
        # acts as a view, so writing to `row` fills `trajectories`.
        row = trajectories[iTraj]
        row['eventID'] = eventId
        row['trackID'] = trajectory.GetTrackId()
        row['parentID'] = trajectory.GetParentId()
        row['pxyz_start'] = (start_mom.X(), start_mom.Y(), start_mom.Z())
        row['pxyz_end'] = (end_mom.X(), end_mom.Y(), end_mom.Z())
        row['xyz_start'] = (start_pos.X(), start_pos.Y(), start_pos.Z())
        row['xyz_end'] = (end_pos.X(), end_pos.Y(), end_pos.Z())
        row['t_start'] = start_pos.T()
        row['t_end'] = end_pos.T()
        row['start_process'] = start_pt.GetProcess()
        row['start_subprocess'] = start_pt.GetSubprocess()
        row['end_process'] = end_pt.GetProcess()
        row['end_subprocess'] = end_pt.GetSubprocess()
        row['pdgId'] = trajectory.GetPDGCode()
    return trajectories


def _fillSegments(hitSegments, eventId, trajectories, segmentsDtype):
    """Return one ``segmentsDtype`` row per TG4HitSegment in *hitSegments*.

    Start/stop coordinates are divided by 10 (presumably mm -> cm; confirm
    against the edep-sim units). trackID and pdgId are copied from
    ``trajectories[Contrib[0]]``. NOTE(review): this assumes the first
    contributor id can be used directly as an index into this event's
    trajectory array; confirm with edep-sim's track-id numbering.
    Time, electron/photon counts, diffusion and pixel_plane are written as
    zero placeholders.
    """
    segment = np.empty(len(hitSegments), dtype=segmentsDtype)
    for iHit, hitSegment in enumerate(hitSegments):
        start, stop = hitSegment.GetStart(), hitSegment.GetStop()
        contributor = trajectories[hitSegment.Contrib[0]]
        energy = hitSegment.GetEnergyDeposit()
        row = segment[iHit]  # structured scalar: writes go into `segment`
        row['eventID'] = eventId
        row['trackID'] = contributor['trackID']
        row['pdgId'] = contributor['pdgId']
        row['x_start'] = start.X() / 10
        row['y_start'] = start.Y() / 10
        row['z_start'] = start.Z() / 10
        row['x_end'] = stop.X() / 10
        row['y_end'] = stop.Y() / 10
        row['z_end'] = stop.Z() / 10
        row['dE'] = energy
        # Length and midpoint are computed from the stored float32 values.
        xd = row['x_end'] - row['x_start']
        yd = row['y_end'] - row['y_start']
        zd = row['z_end'] - row['z_start']
        dx = sqrt(xd**2 + yd**2 + zd**2)
        row['dx'] = dx
        row['x'] = (row['x_start'] + row['x_end']) / 2.
        row['y'] = (row['y_start'] + row['y_end']) / 2.
        row['z'] = (row['z_start'] + row['z_end']) / 2.
        row['dEdx'] = energy / dx if dx > 0 else 0
        for placeholder in ('t', 't_start', 't_end', 'n_electrons',
                            'long_diff', 'tran_diff', 'pixel_plane',
                            'n_photons'):
            row[placeholder] = 0
    return segment


def dump(input_file, output_file):
    """Convert an edep-sim event tree into HDF5 trajectory/segment tables.

    Every entry of the ``EDepSimEvents`` tree in *input_file* is read
    through its ``Event`` branch. For each successfully read entry, one row
    per TG4Trajectory and one row per TG4HitSegment (from every segment
    detector) are collected. ``eventID`` is the tree entry index. The
    concatenated tables are written to *output_file* as the
    ``trajectories`` and ``segments`` datasets. The file is overwritten;
    a dataset is empty if no rows were produced.

    Args:
        input_file: Path of the ROOT file to read.
        output_file: Path of the HDF5 file to create.

    Raises:
        OSError: if ROOT cannot open *input_file*.
    """
    segments_dtype = np.dtype([('eventID', 'u4'), ('z_end', 'f4'),
                               ('trackID', 'u4'), ('tran_diff', 'f4'),
                               ('z_start', 'f4'), ('x_end', 'f4'),
                               ('y_end', 'f4'), ('n_electrons', 'u4'),
                               ('pdgId', 'i4'), ('x_start', 'f4'),
                               ('y_start', 'f4'), ('t_start', 'f4'),
                               ('dx', 'f4'), ('long_diff', 'f4'),
                               ('pixel_plane', 'i4'), ('t_end', 'f4'),
                               ('dEdx', 'f4'), ('dE', 'f4'), ('t', 'f4'),
                               ('y', 'f4'), ('x', 'f4'), ('z', 'f4'),
                               ('n_photons', 'f4')])

    trajectories_dtype = np.dtype([('eventID', 'u4'), ('trackID', 'u4'),
                                   ('parentID', 'i4'),
                                   ('pxyz_start', 'f4', (3,)),
                                   ('xyz_start', 'f4', (3,)), ('t_start', 'f4'),
                                   ('pxyz_end', 'f4', (3,)),
                                   ('xyz_end', 'f4', (3,)), ('t_end', 'f4'),
                                   ('pdgId', 'i4'), ('start_process', 'u4'),
                                   ('start_subprocess', 'u4'),
                                   ('end_process', 'u4'),
                                   ('end_subprocess', 'u4')])

    segments_list = []
    trajectories_list = []

    inputFile = TFile(input_file)
    try:
        if inputFile.IsZombie():
            raise OSError("cannot open ROOT file {!r}".format(input_file))

        # Get the input tree out of the file.
        inputTree = inputFile.Get("EDepSimEvents")
        print("Class:", inputTree.ClassName())

        # Attach a branch to the events; GetEntry() refills `event` in place.
        event = TG4Event()
        inputTree.SetBranchAddress("Event", event)

        entries = inputTree.GetEntriesFast()
        for jentry in range(entries):
            print(jentry)
            nb = inputTree.GetEntry(jentry)
            if nb <= 0:
                # Nothing read for this entry; skip it.
                continue

            print("Class: ", event.ClassName())
            print("Event number:", event.EventId)

            print("Number of trajectories ", len(event.Trajectories))
            trajectories = _fillTrajectories(event, jentry, trajectories_dtype)
            trajectories_list.append(trajectories)

            print("Number of segment containers:", event.SegmentDetectors.size())
            for _containerName, hitSegments in event.SegmentDetectors:
                segments_list.append(
                    _fillSegments(hitSegments, jentry, trajectories,
                                  segments_dtype))
    finally:
        # Release the ROOT file even if reading fails part-way.
        inputFile.Close()

    # np.concatenate rejects an empty list, so write empty tables instead.
    all_trajectories = (np.concatenate(trajectories_list, axis=0)
                        if trajectories_list
                        else np.empty(0, dtype=trajectories_dtype))
    all_segments = (np.concatenate(segments_list, axis=0)
                    if segments_list
                    else np.empty(0, dtype=segments_dtype))

    with h5py.File(output_file, 'w') as f:
        f.create_dataset("trajectories", data=all_trajectories)
        f.create_dataset("segments", data=all_segments)
219 
220 
if __name__ == "__main__":
    # Expose dump(input_file, output_file) as a command-line interface,
    # e.g.  python dumpTree.py <input.root> <output.h5>
    fire.Fire(component=dump)
def printHitSegment(depth, hitSegment)
Definition: dumpTree.py:77
def printPrimaryVertex(depth, primaryVertex)
Definition: dumpTree.py:28
std::pair< float, std::string > P
def printTrajectory(depth, trajectory)
Definition: dumpTree.py:59
auto enumerate(Iterables &&...iterables)
Range-for loop helper tracking the number of iteration.
Definition: enumerate.h:69
def printTrajectoryPoint(depth, trajectoryPoint)
Definition: dumpTree.py:45
def dump(input_file, output_file)
Definition: dumpTree.py:102
def printPrimaryParticle(depth, primaryParticle)
Definition: dumpTree.py:15
def printSegmentContainer(depth, containerName, hitSegments)
Definition: dumpTree.py:96