Run network inference.
141 const float eps{1.1920929e-7};
145 PANDORA_MONITORING_API(SetEveDisplayParameters(this->GetPandora(),
true, DETECTOR_VIEW_XZ, -1.
f, 1.
f, 1.
f));
150 const CaloHitList *pCaloHitList(
nullptr);
151 PANDORA_RETURN_RESULT_IF(STATUS_CODE_SUCCESS, !=, PandoraContentApi::GetList(*
this, listName, pCaloHitList));
153 const HitType view{pCaloHitList->front()->GetHitType()};
155 if (!(view == TPC_VIEW_U || view == TPC_VIEW_V || view == TPC_VIEW_W))
156 return STATUS_CODE_NOT_ALLOWED;
165 this->
GetHitRegion(*pCaloHitList, xMin, xMax, zMin, zMax);
166 const float xRange = (xMax + eps) - (xMin - eps);
167 int nTilesX =
static_cast<int>(std::ceil(xRange /
m_tileSize));
171 const int nTiles = sparseMap.size();
173 CaloHitList trackHits, showerHits, otherHits;
179 for (
int i = 0; i < nTiles; ++i)
181 for (
const CaloHit *pCaloHit : *pCaloHitList)
183 const float x(pCaloHit->GetPositionVector().GetX());
184 const float z(pCaloHit->GetPositionVector().GetZ());
186 const int tileX =
static_cast<int>(std::floor((
x - xMin) /
m_tileSize));
187 const int tileZ =
static_cast<int>(std::floor((
z - zMin) /
m_tileSize));
188 const int tile = sparseMap.at(tileZ * nTilesX + tileX);
192 const float localX = std::fmod(
x - xMin,
m_tileSize);
193 const float localZ = std::fmod(
z - zMin,
m_tileSize);
196 const int pixelZ = (m_imageHeight - 1) - static_cast<int>(std::floor(localZ * m_imageHeight /
m_tileSize));
197 weights[pixelZ][pixelX] += pCaloHit->GetInputEnergy();
207 if (weights[
r][
c] > chargeMax)
208 chargeMax = weights[
r][
c];
209 if (weights[
r][
c] < chargeMin)
210 chargeMin = weights[
r][
c];
213 float chargeRange{chargeMax - chargeMin};
214 if (chargeRange <= 0.
f)
221 auto accessor = input.accessor<
float, 4>();
222 for (
const CaloHit *pCaloHit : *pCaloHitList)
224 const float x(pCaloHit->GetPositionVector().GetX());
225 const float z(pCaloHit->GetPositionVector().GetZ());
227 const int tileX =
static_cast<int>(std::floor((
x - xMin) /
m_tileSize));
228 const int tileZ =
static_cast<int>(std::floor((
z - zMin) /
m_tileSize));
229 const int tile = sparseMap.at(tileZ * nTilesX + tileX);
233 const float localX = std::fmod(
x - xMin,
m_tileSize);
234 const float localZ = std::fmod(
z - zMin,
m_tileSize);
236 const int pixelX =
static_cast<int>(std::floor(localX * m_imageWidth /
m_tileSize));
237 const int pixelZ = (m_imageHeight - 1) - static_cast<int>(std::floor(localZ * m_imageHeight /
m_tileSize));
238 accessor[0][0][pixelZ][pixelX] = (weights[pixelZ][pixelX] - chargeMin) / chargeRange;
239 caloHitToPixelMap.insert(std::make_pair(pCaloHit, std::make_tuple(tileZ, tileX, pixelZ, pixelX)));
249 inputs.push_back(input);
252 auto outputAccessor = output.accessor<
float, 4>();
254 for (
const CaloHit *pCaloHit : *pCaloHitList)
256 auto found{caloHitToPixelMap.find(pCaloHit)};
257 if (
found == caloHitToPixelMap.end())
259 auto pixelMap =
found->second;
260 const int tileZ(std::get<0>(pixelMap));
261 const int tileX(std::get<1>(pixelMap));
262 const int tile = sparseMap.at(tileZ * nTilesX + tileX);
265 const int pixelZ(std::get<2>(pixelMap));
266 const int pixelX(std::get<3>(pixelMap));
269 float probShower = exp(outputAccessor[0][1][pixelZ][pixelX]);
270 float probTrack = exp(outputAccessor[0][2][pixelZ][pixelX]);
271 float probNull = exp(outputAccessor[0][0][pixelZ][pixelX]);
272 if (probShower > probTrack && probShower > probNull)
273 showerHits.push_back(pCaloHit);
274 else if (probTrack > probShower && probTrack > probNull)
275 trackHits.push_back(pCaloHit);
277 otherHits.push_back(pCaloHit);
278 float recipSum = 1.f / (probShower + probTrack);
280 probShower *= recipSum;
281 probTrack *= recipSum;
284 pLArCaloHit->SetTrackProbability(probTrack);
294 const std::string trackListName(
"TrackHits_" + listName);
295 const std::string showerListName(
"ShowerHits_" + listName);
296 const std::string otherListName(
"OtherHits_" + listName);
297 PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &trackHits, trackListName,
BLUE));
298 PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &showerHits, showerListName, RED));
299 PANDORA_MONITORING_API(VisualizeCaloHits(this->GetPandora(), &otherHits, otherListName, BLACK));
305 PANDORA_MONITORING_API(ViewEvent(this->GetPandora()));
308 return STATUS_CODE_SUCCESS;
LArDLHelper::TorchModel m_modelW
Model for the W view.
bool m_visualize
Whether to visualize the track/shower ID scores.
void SetShowerProbability(const float probability)
Set the probability that the hit is shower-like.
LArDLHelper::TorchModel m_modelU
Model for the U view.
torch::jit::script::Module TorchModel
pandora::StringVector m_caloHitListNames
Names of the input calo hit lists.
LArDLHelper::TorchModel m_modelV
Model for the V view.
void GetHitRegion(const pandora::CaloHitList &caloHitList, float &xMin, float &xMax, float &zMin, float &zMax)
Identify the XZ range containing the hits for an event.
static int max(int a, int b)
static void Forward(TorchModel &model, const TorchInputVector &input, TorchOutput &output)
Run a deep learning model.
int m_imageWidth
Width of images in pixels.
int m_imageHeight
Height of images in pixels.
float m_tileSize
Size of tile in cm.
std::map< int, int > PixelToTileMap
void GetSparseTileMap(const pandora::CaloHitList &caloHitList, const float xMin, const float zMin, const int nTilesX, PixelToTileMap &sparseMap)
Populate a map between pixels and tiles.
static void InitialiseInput(const at::IntArrayRef dimensions, TorchInput &tensor)
Create a torch input tensor.
std::map< const pandora::CaloHit *, std::tuple< int, int, int, int > > CaloHitToPixelMap
std::vector< torch::jit::IValue > TorchInputVector