Skip to content

Commit bd3c8d1

Browse files
committed
Merging dev and fixing build issues
1 parent a23fdc9 commit bd3c8d1

File tree

4 files changed

+11
-5
lines changed

4 files changed

+11
-5
lines changed

GPU/GPUTracking/Global/GPUChainTrackingClusterizer.cxx

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -931,18 +931,18 @@ int32_t GPUChainTracking::RunTPCClusterizer(bool synchronizeOutput)
931931

932932
auto stop0 = std::chrono::high_resolution_clock::now();
933933
auto start1 = std::chrono::high_resolution_clock::now();
934-
nnApplication.inferenceNetwork(clustererNN.model_class, clustererNN, iSize, clusterer.modelProbabilities);
934+
nnApplication.networkInference(nnApplication.model_class, clustererNN, iSize, clustererNN.modelProbabilities, evalDtype);
935935
if (nnApplication.model_class.getNumOutputNodes()[0][1] == 1) {
936936
runKernel<GPUTPCNNClusterizerKernels, GPUTPCNNClusterizerKernels::determineClass1Labels>({GetGrid(iSize, lane, GPUReconstruction::krnlDeviceType::CPU), {iSector}}, iSector, evalDtype, 0, batchStart); // Assigning class labels
937937
} else {
938938
runKernel<GPUTPCNNClusterizerKernels, GPUTPCNNClusterizerKernels::determineClass2Labels>({GetGrid(iSize, lane, GPUReconstruction::krnlDeviceType::CPU), {iSector}}, iSector, evalDtype, 0, batchStart); // Assigning class labels
939939
}
940940

941941
if (!clustererNN.nnClusterizerUseCfRegression) {
942-
nnApplication.inferenceNetwork(clustererNN.model_reg_1, clustererNN, iSize, clusterer.outputDataReg1);
942+
nnApplication.networkInference(nnApplication.model_reg_1, clustererNN, iSize, clustererNN.outputDataReg1, evalDtype);
943943
runKernel<GPUTPCNNClusterizerKernels, GPUTPCNNClusterizerKernels::publishClass1Regression>({GetGrid(iSize, lane, GPUReconstruction::krnlDeviceType::CPU), {iSector}}, iSector, evalDtype, 0, batchStart); // Running the NN for regression class 1
944944
if (nnApplication.model_class.getNumOutputNodes()[0][1] > 1 && nnApplication.reg_model_paths.size() > 1) {
945-
nnApplication.inferenceNetwork(clustererNN.model_reg_2, clustererNN, iSize, clusterer.outputDataReg2);
945+
nnApplication.networkInference(nnApplication.model_reg_2, clustererNN, iSize, clustererNN.outputDataReg2, evalDtype);
946946
runKernel<GPUTPCNNClusterizerKernels, GPUTPCNNClusterizerKernels::publishClass2Regression>({GetGrid(iSize, lane, GPUReconstruction::krnlDeviceType::CPU), {iSector}}, iSector, evalDtype, 0, batchStart); // Running the NN for regression class 2
947947
}
948948
}

GPU/GPUTracking/TPCClusterFinder/GPUTPCNNClusterizerHost.cxx

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,7 @@ GPUTPCNNClusterizerHost::GPUTPCNNClusterizerHost(const GPUSettingsProcessingNNcl
5454
}
5555
}
5656

57-
void GPUTPCNNClusterizerHost::networkInference(o2::ml::OrtModel model, GPUTPCNNClusterizer& clusterer, size_t size, float* output)
57+
void GPUTPCNNClusterizerHost::networkInference(o2::ml::OrtModel model, GPUTPCNNClusterizer& clusterer, size_t size, float* output, int32_t dtype)
5858
{
5959
if (dtype == 0) {
6060
model.inference<OrtDataType::Float16_t, float>(clusterer.inputData16, size * clusterer.nnClusterizerElementSize, output);

GPU/GPUTracking/TPCClusterFinder/GPUTPCNNClusterizerHost.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,7 @@ class GPUTPCNNClusterizerHost
3939
GPUTPCNNClusterizerHost() = default;
4040
GPUTPCNNClusterizerHost(const GPUSettingsProcessingNNclusterizer&, GPUTPCNNClusterizer&);
4141

42-
void networkInference(o2::ml::OrtModel model, GPUTPCNNClusterizer& clusterer, size_t size, float* output);
42+
void networkInference(o2::ml::OrtModel model, GPUTPCNNClusterizer& clusterer, size_t size, float* output, int32_t dtype);
4343

4444
std::unordered_map<std::string, std::string> OrtOptions;
4545
o2::ml::OrtModel model_class, model_reg_1, model_reg_2; // For splitting clusters

GPU/GPUTracking/TPCClusterFinder/GPUTPCNNClusterizerKernels.cxx

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,12 @@ using namespace o2::gpu::tpccf;
2222
#include "CfUtils.h"
2323
#include "ClusterAccumulator.h"
2424
#include "ML/3rdparty/GPUORTFloat16.h"
25+
26+
#if !defined(GPUCA_GPUCODE)
27+
#include "GPUHostDataTypes.h"
28+
#include "MCLabelAccumulator.h"
29+
#endif
30+
2531
#ifdef GPUCA_GPUCODE
2632
#include "GPUTPCCFClusterizer.inc"
2733
#endif

0 commit comments

Comments
 (0)