9 #include "Pandora/AlgorithmHeaders.h" 11 #include <torch/script.h> 12 #include <torch/torch.h> 32 DlHitTrackShowerIdAlgorithm::DlHitTrackShowerIdAlgorithm() :
37 m_useTrainingMode(false),
38 m_trainingOutputFile(
"")
62 const int SHOWER{1}, TRACK{2};
65 const CaloHitList *pCaloHitList(
nullptr);
66 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, PandoraContentApi::GetList(*
this, listName, pCaloHitList));
67 const MCParticleList *pMCParticleList(
nullptr);
68 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, PandoraContentApi::GetCurrentList(*
this, pMCParticleList));
70 const HitType view{pCaloHitList->front()->GetHitType()};
72 if (!(view == TPC_VIEW_U || view == TPC_VIEW_V || view == TPC_VIEW_W))
73 return STATUS_CODE_NOT_ALLOWED;
77 if (view == TPC_VIEW_U)
78 trainingOutputFileName +=
"_CaloHitListU.csv";
79 else if (view == TPC_VIEW_V)
80 trainingOutputFileName +=
"_CaloHitListV.csv";
81 else if (view == TPC_VIEW_W)
82 trainingOutputFileName +=
"_CaloHitListW.csv";
90 LArMCParticleHelper::SelectReconstructableMCParticles(
91 pMCParticleList, pCaloHitList, parameters, LArMCParticleHelper::IsBeamNeutrinoFinalState, targetMCParticleToHitsMap);
94 for (
const CaloHit *pCaloHit : *pCaloHitList)
97 float inputEnergy{0.f};
101 const MCParticle *
const pMCParticle(MCParticleHelper::GetMainMCParticle(pCaloHit));
103 if (targetMCParticleToHitsMap.find(pMCParticle) == targetMCParticleToHitsMap.end())
105 if (LArMCParticleHelper::IsDescendentOf(pMCParticle, 2112))
107 inputEnergy = pCaloHit->GetInputEnergy();
108 if (inputEnergy < 0.
f)
111 const int pdg{
std::abs(pMCParticle->GetParticleId())};
112 if (
pdg == 11 ||
pdg == 22)
117 catch (
const StatusCodeException &)
122 featureVector.push_back(static_cast<double>(pCaloHit->GetPositionVector().GetX()));
123 featureVector.push_back(static_cast<double>(pCaloHit->GetPositionVector().GetZ()));
124 featureVector.push_back(static_cast<double>(
tag));
125 featureVector.push_back(static_cast<double>(inputEnergy));
128 featureVector.push_back(static_cast<double>(featureVector.size() / 4));
129 std::rotate(featureVector.rbegin(), featureVector.rbegin() + 1, featureVector.rend());
131 PANDORA_RETURN_RESULT_IF(pandora::STATUS_CODE_SUCCESS, !=, LArMvaHelper::ProduceTrainingExample(trainingOutputFileName,
true, featureVector));
134 return STATUS_CODE_SUCCESS;
141 const float eps{1.1920929e-7};
145 PANDORA_MONITORING_API(SetEveDisplayParameters(this->GetPandora(),
true, DETECTOR_VIEW_XZ, -1.
f, 1.
f, 1.
f));
150 const CaloHitList *pCaloHitList(
nullptr);
151 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, PandoraContentApi::GetList(*
this, listName, pCaloHitList));
153 const HitType view{pCaloHitList->front()->GetHitType()};
155 if (!(view == TPC_VIEW_U || view == TPC_VIEW_V || view == TPC_VIEW_W))
156 return STATUS_CODE_NOT_ALLOWED;
165 this->
GetHitRegion(*pCaloHitList, xMin, xMax, zMin, zMax);
166 const float xRange = (xMax + eps) - (xMin - eps);
167 int nTilesX =
static_cast<int>(std::ceil(xRange /
m_tileSize));
171 const int nTiles = sparseMap.size();
173 CaloHitList trackHits, showerHits, otherHits;
179 for (
int i = 0; i < nTiles; ++i)
181 for (
const CaloHit *pCaloHit : *pCaloHitList)
183 const float x(pCaloHit->GetPositionVector().GetX());
184 const float z(pCaloHit->GetPositionVector().GetZ());
186 const int tileX =
static_cast<int>(std::floor((
x - xMin) /
m_tileSize));
187 const int tileZ =
static_cast<int>(std::floor((
z - zMin) /
m_tileSize));
188 const int tile = sparseMap.at(tileZ * nTilesX + tileX);
192 const float localX = std::fmod(
x - xMin,
m_tileSize);
193 const float localZ = std::fmod(
z - zMin,
m_tileSize);
196 const int pixelZ = (m_imageHeight - 1) - static_cast<int>(std::floor(localZ * m_imageHeight /
m_tileSize));
197 weights[pixelZ][pixelX] += pCaloHit->GetInputEnergy();
207 if (weights[
r][
c] > chargeMax)
208 chargeMax = weights[
r][
c];
209 if (weights[
r][
c] < chargeMin)
210 chargeMin = weights[
r][
c];
213 float chargeRange{chargeMax - chargeMin};
214 if (chargeRange <= 0.
f)
221 auto accessor = input.accessor<
float, 4>();
222 for (
const CaloHit *pCaloHit : *pCaloHitList)
224 const float x(pCaloHit->GetPositionVector().GetX());
225 const float z(pCaloHit->GetPositionVector().GetZ());
227 const int tileX =
static_cast<int>(std::floor((
x - xMin) /
m_tileSize));
228 const int tileZ =
static_cast<int>(std::floor((
z - zMin) /
m_tileSize));
229 const int tile = sparseMap.at(tileZ * nTilesX + tileX);
233 const float localX = std::fmod(
x - xMin,
m_tileSize);
234 const float localZ = std::fmod(
z - zMin,
m_tileSize);
236 const int pixelX =
static_cast<int>(std::floor(localX * m_imageWidth /
m_tileSize));
237 const int pixelZ = (m_imageHeight - 1) - static_cast<int>(std::floor(localZ * m_imageHeight /
m_tileSize));
238 accessor[0][0][pixelZ][pixelX] = (weights[pixelZ][pixelX] - chargeMin) / chargeRange;
239 caloHitToPixelMap.insert(std::make_pair(pCaloHit, std::make_tuple(tileZ, tileX, pixelZ, pixelX)));
249 inputs.push_back(input);
252 auto outputAccessor = output.accessor<
float, 4>();
254 for (
const CaloHit *pCaloHit : *pCaloHitList)
256 auto found{caloHitToPixelMap.find(pCaloHit)};
257 if (
found == caloHitToPixelMap.end())
259 auto pixelMap =
found->second;
260 const int tileZ(std::get<0>(pixelMap));
261 const int tileX(std::get<1>(pixelMap));
262 const int tile = sparseMap.at(tileZ * nTilesX + tileX);
265 const int pixelZ(std::get<2>(pixelMap));
266 const int pixelX(std::get<3>(pixelMap));
269 float probShower = exp(outputAccessor[0][1][pixelZ][pixelX]);
270 float probTrack = exp(outputAccessor[0][2][pixelZ][pixelX]);
271 float probNull = exp(outputAccessor[0][0][pixelZ][pixelX]);
272 if (probShower > probTrack && probShower > probNull)
273 showerHits.push_back(pCaloHit);
274 else if (probTrack > probShower && probTrack > probNull)
275 trackHits.push_back(pCaloHit);
277 otherHits.push_back(pCaloHit);
278 float recipSum = 1.f / (probShower + probTrack);
280 probShower *= recipSum;
281 probTrack *= recipSum;
284 pLArCaloHit->SetTrackProbability(probTrack);
294 const std::string trackListName(
"TrackHits_" + listName);
295 const std::string showerListName(
"ShowerHits_" + listName);
296 const std::string otherListName(
"OtherHits_" + listName);
297 PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &trackHits, trackListName,
BLUE));
298 PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &showerHits, showerListName, RED));
299 PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &otherHits, otherListName, BLACK));
305 PANDORA_MONITORING_API(ViewEvent(this->GetPandora()));
308 return STATUS_CODE_SUCCESS;
319 for (
const CaloHit *pCaloHit : caloHitList)
321 const float x(pCaloHit->GetPositionVector().GetX());
322 const float z(pCaloHit->GetPositionVector().GetZ());
337 const CaloHitList &caloHitList,
const float xMin,
const float zMin,
const int nTilesX,
PixelToTileMap &sparseMap)
340 std::map<int, bool> tilePopulationMap;
341 for (
const CaloHit *pCaloHit : caloHitList)
343 const float x(pCaloHit->GetPositionVector().GetX());
344 const float z(pCaloHit->GetPositionVector().GetZ());
346 const int tileX =
static_cast<int>(std::floor((
x - xMin) /
m_tileSize));
347 const int tileZ =
static_cast<int>(std::floor((
z - zMin) /
m_tileSize));
348 const int tile = tileZ * nTilesX + tileX;
349 tilePopulationMap.insert(std::make_pair(tile,
true));
353 for (
auto element : tilePopulationMap)
357 sparseMap.insert(std::make_pair(element.first, nextTile));
367 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"UseTrainingMode",
m_useTrainingMode));
371 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, XmlHelper::ReadValue(xmlHandle,
"TrainingOutputFileName",
m_trainingOutputFile));
375 bool modelLoaded{
false};
376 PANDORA_RETURN_RESULT_IF_AND_IF(
377 STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"ModelFileNameU",
m_modelFileNameU));
384 PANDORA_RETURN_RESULT_IF_AND_IF(
385 STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"ModelFileNameV",
m_modelFileNameV));
392 PANDORA_RETURN_RESULT_IF_AND_IF(
393 STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"ModelFileNameW",
m_modelFileNameW));
402 std::cout <<
"Error: Inference requested, but no model files were successfully loaded" <<
std::endl;
403 return STATUS_CODE_INVALID_PARAMETER;
407 PANDORA_RETURN_RESULT_IF_AND_IF(
408 STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadVectorOfValues(xmlHandle,
"CaloHitListNames",
m_caloHitListNames));
409 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"ImageHeight",
m_imageHeight));
410 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"ImageWidth",
m_imageWidth));
411 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"TileSize",
m_tileSize));
414 std::cout <<
"Error: Invalid image size specification" <<
std::endl;
415 return STATUS_CODE_INVALID_PARAMETER;
417 PANDORA_RETURN_RESULT_IF_AND_IF(STATUS_CODE_SUCCESS, STATUS_CODE_NOT_FOUND, !=, XmlHelper::ReadValue(xmlHandle,
"Visualize",
m_visualize));
419 return STATUS_CODE_SUCCESS;
Header file for the pfo helper class.
std::unordered_map< const pandora::MCParticle *, pandora::CaloHitList > MCContributionMap
LArDLHelper::TorchModel m_modelW
Model for the W view.
bool m_visualize
Whether to visualize the track shower ID scores.
MvaTypes::MvaFeatureVector MvaFeatureVector
Header file for the lar calo hit class.
void SetShowerProbability(const float probability)
Set the probability that the hit is shower-like.
LArDLHelper::TorchModel m_modelU
Model for the U view.
torch::jit::script::Module TorchModel
std::string m_modelFileNameW
Model file name for W view.
Header file for the lar monitoring helper helper class.
unsigned int m_minHitsForGoodView
the minimum number of Hits for a good view
float m_maxPhotonPropagation
the maximum photon propagation length
std::string m_modelFileNameU
Model file name for U view.
pandora::StringVector m_caloHitListNames
Name of input calo hit list.
Header file for the lar monte carlo particle helper helper class.
pandora::StatusCode Infer()
Run network inference.
std::string m_trainingOutputFile
Output file name for training examples.
LArDLHelper::TorchModel m_modelV
Model for the V view.
void GetHitRegion(const pandora::CaloHitList &caloHitList, float &xMin, float &xMax, float &zMin, float &zMax)
Identify the XZ range containing the hits for an event.
Header file for the file helper class.
static int max(int a, int b)
std::string m_modelFileNameV
Model file name for V view.
static void Forward(TorchModel &model, const TorchInputVector &input, TorchOutput &output)
Run a deep learning model.
pandora::StatusCode Run()
pandora::StatusCode ReadSettings(const pandora::TiXmlHandle xmlHandle)
int m_imageWidth
Width of images in pixels.
virtual ~DlHitTrackShowerIdAlgorithm()
int m_imageHeight
Height of images in pixels.
float m_tileSize
Size of tile in cm.
static pandora::StatusCode LoadModel(const std::string &filename, TorchModel &model)
Loads a deep learning model.
Header file for the deep learning track shower id algorithm.
std::map< int, int > PixelToTileMap
pandora::StatusCode Train()
Produce files that act as inputs to network training.
bool m_useTrainingMode
Training mode.
void GetSparseTileMap(const pandora::CaloHitList &caloHitList, const float xMin, const float zMin, const int nTilesX, PixelToTileMap &sparseMap)
Populate a map between pixels and tiles.
static void InitialiseInput(const at::IntArrayRef dimensions, TorchInput &tensor)
Create a torch input tensor.
std::map< const pandora::CaloHit *, std::tuple< int, int, int, int > > CaloHitToPixelMap
std::vector< torch::jit::IValue > TorchInputVector
QTextStream & endl(QTextStream &s)