9 #include "Helpers/XmlHelper.h" 18 AdaBoostDecisionTree::AdaBoostDecisionTree() : m_pStrongClassifier(nullptr)
52 std::cout <<
"AdaBoostDecisionTree: AdaBoostDecisionTree was already initialized" <<
std::endl;
53 return STATUS_CODE_ALREADY_INITIALIZED;
56 TiXmlDocument xmlDocument(bdtXmlFileName);
58 if (!xmlDocument.LoadFile())
60 std::cout <<
"AdaBoostDecisionTree::Initialize - Invalid xml file." <<
std::endl;
61 return STATUS_CODE_INVALID_PARAMETER;
64 const TiXmlHandle xmlDocumentHandle(&xmlDocument);
65 TiXmlNode *pContainerXmlNode(TiXmlHandle(xmlDocumentHandle).FirstChildElement().Element());
67 while (pContainerXmlNode)
69 if (pContainerXmlNode->ValueStr() !=
"AdaBoostDecisionTree")
70 return STATUS_CODE_FAILURE;
72 const TiXmlHandle currentHandle(pContainerXmlNode);
75 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(currentHandle,
"Name", currentName));
77 if (currentName.empty() || (currentName.size() > 1000))
79 std::cout <<
"AdaBoostDecisionTree::Initialize - Implausible AdaBoostDecisionTree name extracted from xml." <<
std::endl;
80 return STATUS_CODE_INVALID_PARAMETER;
83 if (currentName == bdtName)
86 pContainerXmlNode = pContainerXmlNode->NextSibling();
89 if (!pContainerXmlNode)
91 std::cout <<
"AdaBoostDecisionTree: Could not find an AdaBoostDecisionTree of name " << bdtName <<
std::endl;
92 return STATUS_CODE_NOT_FOUND;
95 const TiXmlHandle xmlHandle(pContainerXmlNode);
101 catch (StatusCodeException &statusCodeException)
105 if (STATUS_CODE_INVALID_PARAMETER == statusCodeException.GetStatusCode())
106 std::cout <<
"AdaBoostDecisionTree: Initialization failure, unknown component in xml file." <<
std::endl;
108 if (STATUS_CODE_FAILURE == statusCodeException.GetStatusCode())
109 std::cout <<
"AdaBoostDecisionTree: Node definition does not contain expected leaf or branch variables." <<
std::endl;
111 return statusCodeException.GetStatusCode();
114 return STATUS_CODE_SUCCESS;
146 std::cout <<
"AdaBoostDecisionTree: Attempting to use an uninitialized bdt" <<
std::endl;
147 throw StatusCodeException(STATUS_CODE_NOT_INITIALIZED);
155 catch (StatusCodeException &statusCodeException)
157 if (STATUS_CODE_NOT_FOUND == statusCodeException.GetStatusCode())
159 std::cout <<
"AdaBoostDecisionTree: Caught exception thrown when trying to cut on an unknown variable." <<
std::endl;
161 else if (STATUS_CODE_INVALID_PARAMETER == statusCodeException.GetStatusCode())
163 std::cout <<
"AdaBoostDecisionTree: Caught exception thrown when classifier weights sum to zero indicating defunct classifier." 166 else if (STATUS_CODE_OUT_OF_RANGE == statusCodeException.GetStatusCode())
168 std::cout <<
"AdaBoostDecisionTree: Caught exception thrown when heirarchy in decision tree is incomplete." <<
std::endl;
172 std::cout <<
"AdaBoostDecisionTree: Unexpected exception thrown." <<
std::endl;
175 throw statusCodeException;
185 m_leftChildNodeId(0),
186 m_rightChildNodeId(0),
192 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle,
"NodeId", m_nodeId));
193 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle,
"ParentNodeId", m_parentNodeId));
195 const StatusCode leftChildNodeIdStatusCode(XmlHelper::ReadValue(*pXmlHandle,
"LeftChildNodeId", m_leftChildNodeId));
196 const StatusCode rightChildNodeIdStatusCode(XmlHelper::ReadValue(*pXmlHandle,
"RightChildNodeId", m_rightChildNodeId));
197 const StatusCode thresholdStatusCode(XmlHelper::ReadValue(*pXmlHandle,
"Threshold", m_threshold));
198 const StatusCode variableIdStatusCode(XmlHelper::ReadValue(*pXmlHandle,
"VariableId", m_variableId));
199 const StatusCode outcomeStatusCode(XmlHelper::ReadValue(*pXmlHandle,
"Outcome", m_outcome));
201 if (STATUS_CODE_SUCCESS == leftChildNodeIdStatusCode || STATUS_CODE_SUCCESS == rightChildNodeIdStatusCode ||
202 STATUS_CODE_SUCCESS == thresholdStatusCode || STATUS_CODE_SUCCESS == variableIdStatusCode)
207 else if (outcomeStatusCode == STATUS_CODE_SUCCESS)
217 throw StatusCodeException(STATUS_CODE_FAILURE);
224 m_nodeId(rhs.m_nodeId),
225 m_parentNodeId(rhs.m_parentNodeId),
226 m_leftChildNodeId(rhs.m_leftChildNodeId),
227 m_rightChildNodeId(rhs.m_rightChildNodeId),
228 m_isLeaf(rhs.m_isLeaf),
229 m_threshold(rhs.m_threshold),
230 m_variableId(rhs.m_variableId),
231 m_outcome(rhs.m_outcome)
265 for (TiXmlElement *pHeadTiXmlElement = pXmlHandle->FirstChildElement().ToElement(); pHeadTiXmlElement != NULL;
266 pHeadTiXmlElement = pHeadTiXmlElement->NextSiblingElement())
268 if (
"TreeIndex" == pHeadTiXmlElement->ValueStr())
270 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle,
"TreeIndex", m_treeId));
272 else if (
"TreeWeight" == pHeadTiXmlElement->ValueStr())
274 PANDORA_THROW_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(*pXmlHandle,
"TreeWeight", m_weight));
276 else if (
"Node" == pHeadTiXmlElement->ValueStr())
278 const TiXmlHandle nodeHandle(pHeadTiXmlElement);
279 const Node *pNode =
new Node(&nodeHandle);
280 m_idToNodeMap.insert(IdToNodeMap::value_type(pNode->
GetNodeId(), pNode));
291 const Node *pNode =
new Node(*(mapEntry.second));
304 const Node *pNode =
new Node(*(mapEntry.second));
320 delete mapEntry.second;
334 const Node *pActiveNode(
nullptr);
342 throw StatusCodeException(STATUS_CODE_OUT_OF_RANGE);
345 if (pActiveNode->
IsLeaf())
348 if (static_cast<int>(features.size()) <= pActiveNode->
GetVariableId())
349 throw StatusCodeException(STATUS_CODE_NOT_FOUND);
366 TiXmlElement *pCurrentXmlElement = pXmlHandle->FirstChild().Element();
368 while (pCurrentXmlElement)
370 if (STATUS_CODE_SUCCESS != this->ReadComponent(pCurrentXmlElement))
371 throw StatusCodeException(STATUS_CODE_INVALID_PARAMETER);
373 pCurrentXmlElement = pCurrentXmlElement->NextSiblingElement();
382 m_weakClassifiers.emplace_back(
new WeakClassifier(*pWeakClassifier));
392 m_weakClassifiers.emplace_back(
new WeakClassifier(*pWeakClassifier));
402 for (
const WeakClassifier *
const pWeakClassifier : m_weakClassifiers)
403 delete pWeakClassifier;
410 double score(0.), weights(0.);
412 for (
const WeakClassifier *
const pWeakClassifier : m_weakClassifiers)
414 weights += pWeakClassifier->GetWeight();
416 if (pWeakClassifier->Predict(features))
418 score += pWeakClassifier->GetWeight();
422 score -= pWeakClassifier->GetWeight();
426 if (weights > std::numeric_limits<double>::epsilon())
432 throw StatusCodeException(STATUS_CODE_INVALID_PARAMETER);
442 const std::string componentName(pCurrentXmlElement->ValueStr());
443 TiXmlHandle currentHandle(pCurrentXmlElement);
446 return STATUS_CODE_SUCCESS;
450 m_weakClassifiers.emplace_back(
new WeakClassifier(¤tHandle));
451 return STATUS_CODE_SUCCESS;
454 return STATUS_CODE_INVALID_PARAMETER;
int m_rightChildNodeId
Right child node id.
WeakClassifiers m_weakClassifiers
Vector of weak classifiers.
WeakClassifier & operator=(const WeakClassifier &rhs)
Assignment operator.
int GetLeftChildNodeId() const
Return left child node id.
int GetVariableId() const
Return cut variable.
MvaTypes::MvaFeatureVector MvaFeatureVector
IdToNodeMap m_idToNodeMap
Decision tree nodes.
bool GetOutcome() const
Return outcome.
WeakClassifier(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
bool Classify(const LArMvaHelper::MvaFeatureVector &features) const
Classify the set of input features based on the trained model.
WeakClassifier class containing a decision tree and a weight.
int m_treeId
Decision tree id.
double m_threshold
Threshold used for the decision, if this is a decision node.
int GetNodeId() const
Return node id.
double Predict(const LArMvaHelper::MvaFeatureVector &features) const
Predict signal or background based on trained data.
AdaBoostDecisionTree class.
int m_parentNodeId
Parent node id.
double GetThreshold() const
Return node threshold.
bool m_outcome
Outcome, if this is a leaf node.
StrongClassifier class used in application of adaptive boost decision tree.
Header file for the LAr adaptive boosted decision tree class.
Node & operator=(const Node &rhs)
Assignment operator.
bool EvaluateNode(const int nodeId, const LArMvaHelper::MvaFeatureVector &features) const
Evaluate the node and return its outcome.
int m_leftChildNodeId
Left child node id.
~WeakClassifier()
Destructor.
static int max(int a, int b)
Node class used for representing a decision tree.
Node(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
double CalculateProbability(const LArMvaHelper::MvaFeatureVector &features) const
Calculate the classification probability for a set of input features, based on the trained model...
pandora::StatusCode ReadComponent(pandora::TiXmlElement *pCurrentXmlElement)
Read xml element and if weak classifier add to member variables.
StrongClassifier(const pandora::TiXmlHandle *const pXmlHandle)
Constructor using xml handle to set member variables.
int GetRightChildNodeId() const
Return right child node id.
~AdaBoostDecisionTree()
Destructor.
AdaBoostDecisionTree()
Constructor.
double m_weight
Boost weight.
double CalculateClassificationScore(const LArMvaHelper::MvaFeatureVector &features) const
Calculate the classification score for a set of input features, based on the trained model...
pandora::StatusCode Initialize(const std::string &parameterLocation, const std::string &bdtName)
Initialize the bdt model.
AdaBoostDecisionTree & operator=(const AdaBoostDecisionTree &rhs)
Assignment operator.
double CalculateScore(const LArMvaHelper::MvaFeatureVector &features) const
Calculate score for input features using strong classifier.
StrongClassifier & operator=(const StrongClassifier &rhs)
Assignment operator.
StrongClassifier * m_pStrongClassifier
Strong adaptive boost tree classifier.
bool Predict(const LArMvaHelper::MvaFeatureVector &features) const
Predict signal or background based on trained data.
QTextStream & endl(QTextStream &s)
int m_variableId
Variable to cut on, if this is a decision node.
bool IsLeaf() const
Return is the node a leaf.
bool m_isLeaf
Is node a leaf.
~StrongClassifier()
Destructor.