LArAdaBoostDecisionTree.cc
Go to the documentation of this file.
1 /**
2  * @file larpandoracontent/LArObjects/LArAdaBoostDecisionTree.cc
3  *
4  * @brief Implementation of the lar adaptive boost decision tree class.
5  *
6  * $Log: $
7  */
8 
9 #include "Helpers/XmlHelper.h"
10 
12 
13 using namespace pandora;
14 
15 namespace lar_content
16 {
17 
18 AdaBoostDecisionTree::AdaBoostDecisionTree() : m_pStrongClassifier(nullptr)
19 {
20 }
21 
22 //------------------------------------------------------------------------------------------------------------------------------------------
23 
25 {
27 }
28 
29 //------------------------------------------------------------------------------------------------------------------------------------------
30 
32 {
33  if (this != &rhs)
35 
36  return *this;
37 }
38 
39 //------------------------------------------------------------------------------------------------------------------------------------------
40 
42 {
43  delete m_pStrongClassifier;
44 }
45 
46 //------------------------------------------------------------------------------------------------------------------------------------------
47 
48 StatusCode AdaBoostDecisionTree::Initialize(const std::string &bdtXmlFileName, const std::string &bdtName)
49 {
51  {
52  std::cout << "AdaBoostDecisionTree: AdaBoostDecisionTree was already initialized" << std::endl;
53  return STATUS_CODE_ALREADY_INITIALIZED;
54  }
55 
56  TiXmlDocument xmlDocument(bdtXmlFileName);
57 
58  if (!xmlDocument.LoadFile())
59  {
60  std::cout << "AdaBoostDecisionTree::Initialize - Invalid xml file." << std::endl;
61  return STATUS_CODE_INVALID_PARAMETER;
62  }
63 
64  const TiXmlHandle xmlDocumentHandle(&xmlDocument);
65  TiXmlNode *pContainerXmlNode(TiXmlHandle(xmlDocumentHandle).FirstChildElement().Element());
66 
67  while (pContainerXmlNode)
68  {
69  if (pContainerXmlNode->ValueStr() != "AdaBoostDecisionTree")
70  return STATUS_CODE_FAILURE;
71 
72  const TiXmlHandle currentHandle(pContainerXmlNode);
73 
74  std::string currentName;
75  PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(currentHandle, "Name", currentName));
76 
77  if (currentName.empty() || (currentName.size() > 1000))
78  {
79  std::cout << "AdaBoostDecisionTree::Initialize - Implausible AdaBoostDecisionTree name extracted from xml." << std::endl;
80  return STATUS_CODE_INVALID_PARAMETER;
81  }
82 
83  if (currentName == bdtName)
84  break;
85 
86  pContainerXmlNode = pContainerXmlNode->NextSibling();
87  }
88 
89  if (!pContainerXmlNode)
90  {
91  std::cout << "AdaBoostDecisionTree: Could not find an AdaBoostDecisionTree of name " << bdtName << std::endl;
92  return STATUS_CODE_NOT_FOUND;
93  }
94 
95  const TiXmlHandle xmlHandle(pContainerXmlNode);
96 
97  try
98  {
99  m_pStrongClassifier = new StrongClassifier(&xmlHandle);
100  }
101  catch (StatusCodeException &statusCodeException)
102  {
103  delete m_pStrongClassifier;
104 
105  if (STATUS_CODE_INVALID_PARAMETER == statusCodeException.GetStatusCode())
106  std::cout << "AdaBoostDecisionTree: Initialization failure, unknown component in xml file." << std::endl;
107 
108  if (STATUS_CODE_FAILURE == statusCodeException.GetStatusCode())
109  std::cout << "AdaBoostDecisionTree: Node definition does not contain expected leaf or branch variables." << std::endl;
110 
111  return statusCodeException.GetStatusCode();
112  }
113 
114  return STATUS_CODE_SUCCESS;
115 }
116 
117 //------------------------------------------------------------------------------------------------------------------------------------------
118 
120 {
121  return ((this->CalculateScore(features) > 0.) ? true : false);
122 }
123 
124 //------------------------------------------------------------------------------------------------------------------------------------------
125 
127 {
128  return this->CalculateScore(features);
129 }
130 
131 //------------------------------------------------------------------------------------------------------------------------------------------
132 
134 {
135  // ATTN: BDT score, once normalised by total weight, is confined to the range -1 to +1. This linear mapping places the score in the
136  // range 0 to 1 so that it may be interpreted as a probability.
137  return (this->CalculateScore(features) + 1.) * 0.5;
138 }
139 
140 //------------------------------------------------------------------------------------------------------------------------------------------
141 
143 {
144  if (!m_pStrongClassifier)
145  {
146  std::cout << "AdaBoostDecisionTree: Attempting to use an uninitialized bdt" << std::endl;
147  throw StatusCodeException(STATUS_CODE_NOT_INITIALIZED);
148  }
149 
150  try
151  {
152  // TODO: Add consistency check for number of features, bearing in mind not all features in a bdt may be used
153  return m_pStrongClassifier->Predict(features);
154  }
155  catch (StatusCodeException &statusCodeException)
156  {
157  if (STATUS_CODE_NOT_FOUND == statusCodeException.GetStatusCode())
158  {
159  std::cout << "AdaBoostDecisionTree: Caught exception thrown when trying to cut on an unknown variable." << std::endl;
160  }
161  else if (STATUS_CODE_INVALID_PARAMETER == statusCodeException.GetStatusCode())
162  {
163  std::cout << "AdaBoostDecisionTree: Caught exception thrown when classifier weights sum to zero indicating defunct classifier."
164  << std::endl;
165  }
166  else if (STATUS_CODE_OUT_OF_RANGE == statusCodeException.GetStatusCode())
167  {
168  std::cout << "AdaBoostDecisionTree: Caught exception thrown when heirarchy in decision tree is incomplete." << std::endl;
169  }
170  else
171  {
172  std::cout << "AdaBoostDecisionTree: Unexpected exception thrown." << std::endl;
173  }
174 
175  throw statusCodeException;
176  }
177 }
178 
179 //------------------------------------------------------------------------------------------------------------------------------------------
180 //------------------------------------------------------------------------------------------------------------------------------------------
181 
182 AdaBoostDecisionTree::Node::Node(const TiXmlHandle *const pXmlHandle) :
183  m_nodeId(0),
184  m_parentNodeId(0),
185  m_leftChildNodeId(0),
186  m_rightChildNodeId(0),
187  m_isLeaf(false),
188  m_threshold(0.),
189  m_variableId(0),
190  m_outcome(false)
191 {
192  PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle, "NodeId", m_nodeId));
193  PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle, "ParentNodeId", m_parentNodeId));
194 
195  const StatusCode leftChildNodeIdStatusCode(XmlHelper::ReadValue(*pXmlHandle, "LeftChildNodeId", m_leftChildNodeId));
196  const StatusCode rightChildNodeIdStatusCode(XmlHelper::ReadValue(*pXmlHandle, "RightChildNodeId", m_rightChildNodeId));
197  const StatusCode thresholdStatusCode(XmlHelper::ReadValue(*pXmlHandle, "Threshold", m_threshold));
198  const StatusCode variableIdStatusCode(XmlHelper::ReadValue(*pXmlHandle, "VariableId", m_variableId));
199  const StatusCode outcomeStatusCode(XmlHelper::ReadValue(*pXmlHandle, "Outcome", m_outcome));
200 
201  if (STATUS_CODE_SUCCESS == leftChildNodeIdStatusCode || STATUS_CODE_SUCCESS == rightChildNodeIdStatusCode ||
202  STATUS_CODE_SUCCESS == thresholdStatusCode || STATUS_CODE_SUCCESS == variableIdStatusCode)
203  {
204  m_isLeaf = false;
205  m_outcome = false;
206  }
207  else if (outcomeStatusCode == STATUS_CODE_SUCCESS)
208  {
209  m_isLeaf = true;
210  m_leftChildNodeId = std::numeric_limits<int>::max();
211  m_rightChildNodeId = std::numeric_limits<int>::max();
212  m_threshold = std::numeric_limits<double>::max();
213  m_variableId = std::numeric_limits<int>::max();
214  }
215  else
216  {
217  throw StatusCodeException(STATUS_CODE_FAILURE);
218  }
219 }
220 
221 //------------------------------------------------------------------------------------------------------------------------------------------
222 
224  m_nodeId(rhs.m_nodeId),
225  m_parentNodeId(rhs.m_parentNodeId),
226  m_leftChildNodeId(rhs.m_leftChildNodeId),
227  m_rightChildNodeId(rhs.m_rightChildNodeId),
228  m_isLeaf(rhs.m_isLeaf),
229  m_threshold(rhs.m_threshold),
230  m_variableId(rhs.m_variableId),
231  m_outcome(rhs.m_outcome)
232 {
233 }
234 
235 //------------------------------------------------------------------------------------------------------------------------------------------
236 
238 {
239  if (this != &rhs)
240  {
241  m_nodeId = rhs.m_nodeId;
245  m_isLeaf = rhs.m_isLeaf;
246  m_threshold = rhs.m_threshold;
248  m_outcome = rhs.m_outcome;
249  }
250 
251  return *this;
252 }
253 
254 //------------------------------------------------------------------------------------------------------------------------------------------
255 
257 {
258 }
259 
260 //------------------------------------------------------------------------------------------------------------------------------------------
261 //------------------------------------------------------------------------------------------------------------------------------------------
262 
263 AdaBoostDecisionTree::WeakClassifier::WeakClassifier(const TiXmlHandle *const pXmlHandle) : m_weight(0.), m_treeId(0)
264 {
265  for (TiXmlElement *pHeadTiXmlElement = pXmlHandle->FirstChildElement().ToElement(); pHeadTiXmlElement != NULL;
266  pHeadTiXmlElement = pHeadTiXmlElement->NextSiblingElement())
267  {
268  if ("TreeIndex" == pHeadTiXmlElement->ValueStr())
269  {
270  PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle, "TreeIndex", m_treeId));
271  }
272  else if ("TreeWeight" == pHeadTiXmlElement->ValueStr())
273  {
274  PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle, "TreeWeight", m_weight));
275  }
276  else if ("Node" == pHeadTiXmlElement->ValueStr())
277  {
278  const TiXmlHandle nodeHandle(pHeadTiXmlElement);
279  const Node *pNode = new Node(&nodeHandle);
280  m_idToNodeMap.insert(IdToNodeMap::value_type(pNode->GetNodeId(), pNode));
281  }
282  }
283 }
284 
285 //------------------------------------------------------------------------------------------------------------------------------------------
286 
287 AdaBoostDecisionTree::WeakClassifier::WeakClassifier(const WeakClassifier &rhs) : m_weight(rhs.m_weight), m_treeId(rhs.m_treeId)
288 {
289  for (const auto &mapEntry : rhs.m_idToNodeMap)
290  {
291  const Node *pNode = new Node(*(mapEntry.second));
292  m_idToNodeMap.insert(IdToNodeMap::value_type(pNode->GetNodeId(), pNode));
293  }
294 }
295 
296 //------------------------------------------------------------------------------------------------------------------------------------------
297 
299 {
300  if (this != &rhs)
301  {
302  for (const auto &mapEntry : rhs.m_idToNodeMap)
303  {
304  const Node *pNode = new Node(*(mapEntry.second));
305  m_idToNodeMap.insert(IdToNodeMap::value_type(pNode->GetNodeId(), pNode));
306  }
307 
308  m_weight = rhs.m_weight;
309  m_treeId = rhs.m_treeId;
310  }
311 
312  return *this;
313 }
314 
315 //------------------------------------------------------------------------------------------------------------------------------------------
316 
318 {
319  for (const auto &mapEntry : m_idToNodeMap)
320  delete mapEntry.second;
321 }
322 
323 //------------------------------------------------------------------------------------------------------------------------------------------
324 
326 {
327  return this->EvaluateNode(0, features);
328 }
329 
330 //------------------------------------------------------------------------------------------------------------------------------------------
331 
333 {
334  const Node *pActiveNode(nullptr);
335 
336  if (m_idToNodeMap.find(nodeId) != m_idToNodeMap.end())
337  {
338  pActiveNode = m_idToNodeMap.at(nodeId);
339  }
340  else
341  {
342  throw StatusCodeException(STATUS_CODE_OUT_OF_RANGE);
343  }
344 
345  if (pActiveNode->IsLeaf())
346  return pActiveNode->GetOutcome();
347 
348  if (static_cast<int>(features.size()) <= pActiveNode->GetVariableId())
349  throw StatusCodeException(STATUS_CODE_NOT_FOUND);
350 
351  if (features.at(pActiveNode->GetVariableId()).Get() <= pActiveNode->GetThreshold())
352  {
353  return this->EvaluateNode(pActiveNode->GetLeftChildNodeId(), features);
354  }
355  else
356  {
357  return this->EvaluateNode(pActiveNode->GetRightChildNodeId(), features);
358  }
359 }
360 
361 //------------------------------------------------------------------------------------------------------------------------------------------
362 //------------------------------------------------------------------------------------------------------------------------------------------
363 
364 AdaBoostDecisionTree::StrongClassifier::StrongClassifier(const TiXmlHandle *const pXmlHandle)
365 {
366  TiXmlElement *pCurrentXmlElement = pXmlHandle->FirstChild().Element();
367 
368  while (pCurrentXmlElement)
369  {
370  if (STATUS_CODE_SUCCESS != this->ReadComponent(pCurrentXmlElement))
371  throw StatusCodeException(STATUS_CODE_INVALID_PARAMETER);
372 
373  pCurrentXmlElement = pCurrentXmlElement->NextSiblingElement();
374  }
375 }
376 
377 //------------------------------------------------------------------------------------------------------------------------------------------
378 
380 {
381  for (const WeakClassifier *const pWeakClassifier : rhs.m_weakClassifiers)
382  m_weakClassifiers.emplace_back(new WeakClassifier(*pWeakClassifier));
383 }
384 
385 //------------------------------------------------------------------------------------------------------------------------------------------
386 
388 {
389  if (this != &rhs)
390  {
391  for (const WeakClassifier *const pWeakClassifier : rhs.m_weakClassifiers)
392  m_weakClassifiers.emplace_back(new WeakClassifier(*pWeakClassifier));
393  }
394 
395  return *this;
396 }
397 
398 //------------------------------------------------------------------------------------------------------------------------------------------
399 
401 {
402  for (const WeakClassifier *const pWeakClassifier : m_weakClassifiers)
403  delete pWeakClassifier;
404 }
405 
406 //------------------------------------------------------------------------------------------------------------------------------------------
407 
409 {
410  double score(0.), weights(0.);
411 
412  for (const WeakClassifier *const pWeakClassifier : m_weakClassifiers)
413  {
414  weights += pWeakClassifier->GetWeight();
415 
416  if (pWeakClassifier->Predict(features))
417  {
418  score += pWeakClassifier->GetWeight();
419  }
420  else
421  {
422  score -= pWeakClassifier->GetWeight();
423  }
424  }
425 
426  if (weights > std::numeric_limits<double>::epsilon())
427  {
428  score /= weights;
429  }
430  else
431  {
432  throw StatusCodeException(STATUS_CODE_INVALID_PARAMETER);
433  }
434 
435  return score;
436 }
437 
438 //------------------------------------------------------------------------------------------------------------------------------------------
439 
440 StatusCode AdaBoostDecisionTree::StrongClassifier::ReadComponent(TiXmlElement *pCurrentXmlElement)
441 {
442  const std::string componentName(pCurrentXmlElement->ValueStr());
443  TiXmlHandle currentHandle(pCurrentXmlElement);
444 
445  if ((std::string("Name") == componentName) || (std::string("Timestamp") == componentName))
446  return STATUS_CODE_SUCCESS;
447 
448  if (std::string("DecisionTree") == componentName)
449  {
450  m_weakClassifiers.emplace_back(new WeakClassifier(&currentHandle));
451  return STATUS_CODE_SUCCESS;
452  }
453 
454  return STATUS_CODE_INVALID_PARAMETER;
455 }
456 
457 } // namespace lar_content
WeakClassifiers m_weakClassifiers
Vector of weak classifiers.
WeakClassifier & operator=(const WeakClassifier &rhs)
Assignment operator.
int GetLeftChildNodeId() const
Return left child node id.
int GetVariableId() const
Return cut variable.
MvaTypes::MvaFeatureVector MvaFeatureVector
Definition: LArMvaHelper.h:58
QMapNodeBase Node
Definition: qmap.cpp:41
std::string string
Definition: nybbler.cc:12
WeakClassifier(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
bool Classify(const LArMvaHelper::MvaFeatureVector &features) const
Classify the set of input features based on the trained model.
WeakClassifier class containing a decision tree and a weight.
double m_threshold
Threshold used for decision if decision node.
double Predict(const LArMvaHelper::MvaFeatureVector &features) const
Predict signal or background based on trained data.
const char features[]
Definition: feature_tests.c:2
double GetThreshold() const
Return node threshold.
StrongClassifier class used in application of adaptive boost decision tree.
Header file for the lar adaptive boosted decision tree class.
Node & operator=(const Node &rhs)
Assignment operator.
bool EvaluateNode(const int nodeId, const LArMvaHelper::MvaFeatureVector &features) const
Evaluate node and return outcome.
static int max(int a, int b)
Node class used for representing a decision tree.
Node(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
double CalculateProbability(const LArMvaHelper::MvaFeatureVector &features) const
Calculate the classification probability for a set of input features, based on the trained model...
pandora::StatusCode ReadComponent(pandora::TiXmlElement *pCurrentXmlElement)
Read xml element and if weak classifier add to member variables.
StrongClassifier(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
int GetRightChildNodeId() const
Return right child node id.
double CalculateClassificationScore(const LArMvaHelper::MvaFeatureVector &features) const
Calculate the classification score for a set of input features, based on the trained model...
pandora::StatusCode Initialize(const std::string &parameterLocation, const std::string &bdtName)
Initialize the bdt model.
AdaBoostDecisionTree & operator=(const AdaBoostDecisionTree &rhs)
Assignment operator.
double CalculateScore(const LArMvaHelper::MvaFeatureVector &features) const
Calculate score for input features using strong classifier.
StrongClassifier & operator=(const StrongClassifier &rhs)
Assignment operator.
StrongClassifier * m_pStrongClassifier
Strong adaptive boost tree classifier.
bool Predict(const LArMvaHelper::MvaFeatureVector &features) const
Predict signal or background based on trained data.
QTextStream & endl(QTextStream &s)
int m_variableId
Variable cut on for decision if decision node.
bool IsLeaf() const
Return is the node a leaf.