@@ -943,9 +943,6 @@ int32_t GPUChainTracking::RunTPCClusterizer(bool synchronizeOutput)
943943 GPUTPCNNClusterizer& clustererNNShadow = doGPU ? processorsShadow ()->tpcNNClusterer [lane] : clustererNN;
944944 GPUTPCNNClusterizerHost& nnApplication = nnApplications[lane];
945945
946- LOG (info) << " clustererNNShadow.inputData32: " << clustererNNShadow.inputData32 ;
947- LOG (info) << " clustererShadow.mPclusterInRow: " << clustererShadow.mPclusterInRow ;
948-
949946 int withMC = (doGPU && propagateMCLabels);
950947
951948 if (clustererNNShadow.nnClusterizerUseCfRegression || (int )(nn_settings.nnClusterizerApplyCfDeconvolution )) {
@@ -963,19 +960,58 @@ int32_t GPUChainTracking::RunTPCClusterizer(bool synchronizeOutput)
963960
964961 auto stop0 = std::chrono::high_resolution_clock::now ();
965962 auto start1 = std::chrono::high_resolution_clock::now ();
966- nnApplication.networkInference (nnApplication.model_class , clustererNNShadow, iSize, clustererNNShadow.modelProbabilities , clustererNNShadow.nnInferenceInputDType );
963+
964+ // nnApplication.networkInference(nnApplication.model_class, clustererNNShadow, iSize, clustererNNShadow.modelProbabilities, clustererNNShadow.nnInferenceInputDType);
965+ if (clustererNNShadow.nnInferenceInputDType == 0 ) {
966+ if (clustererNNShadow.nnInferenceOutputDType == 0 ) {
967+ (nnApplication.model_class ).inference (clustererNNShadow.inputData_16 , iSize, clustererNNShadow.modelProbabilities_16 );
968+ } else if (clustererNNShadow.nnInferenceOutputDType == 1 ) {
969+ (nnApplication.model_class ).inference (clustererNNShadow.inputData_16 , iSize, clustererNNShadow.modelProbabilities_32 );
970+ }
971+ } else if (clustererNNShadow.nnInferenceInputDType == 1 ) {
972+ if (clustererNNShadow.nnInferenceOutputDType == 0 ) {
973+ (nnApplication.model_class ).inference (clustererNNShadow.inputData_32 , iSize, clustererNNShadow.modelProbabilities_16 );
974+ } else if (clustererNNShadow.nnInferenceOutputDType == 1 ) {
975+ (nnApplication.model_class ).inference (clustererNNShadow.inputData_32 , iSize, clustererNNShadow.modelProbabilities_32 );
976+ }
977+ }
967978 if (nnApplication.model_class .getNumOutputNodes ()[0 ][1 ] == 1 ) {
968- runKernel<GPUTPCNNClusterizerKernels, GPUTPCNNClusterizerKernels::determineClass1Labels>({GetGrid (iSize, lane), krnlRunRangeNone}, iSector, clustererNNShadow.nnInferenceInputDType , withMC, batchStart); // Assigning class labels
979+ runKernel<GPUTPCNNClusterizerKernels, GPUTPCNNClusterizerKernels::determineClass1Labels>({GetGrid (iSize, lane), krnlRunRangeNone}, iSector, clustererNNShadow.nnInferenceOutputDType , withMC, batchStart); // Assigning class labels
969980 } else {
970- runKernel<GPUTPCNNClusterizerKernels, GPUTPCNNClusterizerKernels::determineClass2Labels>({GetGrid (iSize, lane), krnlRunRangeNone}, iSector, clustererNNShadow.nnInferenceInputDType , withMC, batchStart); // Assigning class labels
981+ runKernel<GPUTPCNNClusterizerKernels, GPUTPCNNClusterizerKernels::determineClass2Labels>({GetGrid (iSize, lane), krnlRunRangeNone}, iSector, clustererNNShadow.nnInferenceOutputDType , withMC, batchStart); // Assigning class labels
971982 }
972-
973983 if (!clustererNNShadow.nnClusterizerUseCfRegression ) {
974- nnApplication.networkInference (nnApplication.model_reg_1 , clustererNNShadow, iSize, clustererNNShadow.outputDataReg1 , clustererNNShadow.nnInferenceInputDType );
975- runKernel<GPUTPCNNClusterizerKernels, GPUTPCNNClusterizerKernels::publishClass1Regression>({GetGrid (iSize, lane), krnlRunRangeNone}, iSector, clustererNNShadow.nnInferenceInputDType , withMC, batchStart); // Running the NN for regression class 1
984+ // nnApplication.networkInference(nnApplication.model_reg_1, clustererNNShadow, iSize, clustererNNShadow.outputDataReg1, clustererNNShadow.nnInferenceInputDType);
985+ if (clustererNNShadow.nnInferenceInputDType == 0 ) {
986+ if (clustererNNShadow.nnInferenceOutputDType == 0 ) {
987+ (nnApplication.model_reg_1 ).inference (clustererNNShadow.inputData_16 , iSize, clustererNNShadow.outputDataReg1_16 );
988+ } else if (clustererNNShadow.nnInferenceOutputDType == 1 ) {
989+ (nnApplication.model_reg_1 ).inference (clustererNNShadow.inputData_16 , iSize, clustererNNShadow.outputDataReg1_32 );
990+ }
991+ } else if (clustererNNShadow.nnInferenceInputDType == 1 ) {
992+ if (clustererNNShadow.nnInferenceOutputDType == 0 ) {
993+ (nnApplication.model_reg_1 ).inference (clustererNNShadow.inputData_32 , iSize, clustererNNShadow.outputDataReg1_16 );
994+ } else if (clustererNNShadow.nnInferenceOutputDType == 1 ) {
995+ (nnApplication.model_reg_1 ).inference (clustererNNShadow.inputData_32 , iSize, clustererNNShadow.outputDataReg1_32 );
996+ }
997+ }
998+ runKernel<GPUTPCNNClusterizerKernels, GPUTPCNNClusterizerKernels::publishClass1Regression>({GetGrid (iSize, lane), krnlRunRangeNone}, iSector, clustererNNShadow.nnInferenceOutputDType , withMC, batchStart); // Running the NN for regression class 1
976999 if (nnApplication.model_class .getNumOutputNodes ()[0 ][1 ] > 1 && nnApplication.model_reg_2 .isInitialized ()) {
977- nnApplication.networkInference (nnApplication.model_reg_2 , clustererNNShadow, iSize, clustererNNShadow.outputDataReg2 , clustererNNShadow.nnInferenceInputDType );
978- runKernel<GPUTPCNNClusterizerKernels, GPUTPCNNClusterizerKernels::publishClass2Regression>({GetGrid (iSize, lane), krnlRunRangeNone}, iSector, clustererNNShadow.nnInferenceInputDType , withMC, batchStart); // Running the NN for regression class 2
1000+ // nnApplication.networkInference(nnApplication.model_reg_2, clustererNNShadow, iSize, clustererNNShadow.outputDataReg2, clustererNNShadow.nnInferenceInputDType);
1001+ if (clustererNNShadow.nnInferenceInputDType == 0 ) {
1002+ if (clustererNNShadow.nnInferenceOutputDType == 0 ) {
1003+ (nnApplication.model_reg_2 ).inference (clustererNNShadow.inputData_16 , iSize, clustererNNShadow.outputDataReg2_16 );
1004+ } else if (clustererNNShadow.nnInferenceOutputDType == 1 ) {
1005+ (nnApplication.model_reg_2 ).inference (clustererNNShadow.inputData_16 , iSize, clustererNNShadow.outputDataReg2_32 );
1006+ }
1007+ } else if (clustererNNShadow.nnInferenceInputDType == 1 ) {
1008+ if (clustererNNShadow.nnInferenceOutputDType == 0 ) {
1009+ (nnApplication.model_reg_2 ).inference (clustererNNShadow.inputData_32 , iSize, clustererNNShadow.outputDataReg2_16 );
1010+ } else if (clustererNNShadow.nnInferenceOutputDType == 1 ) {
1011+ (nnApplication.model_reg_2 ).inference (clustererNNShadow.inputData_32 , iSize, clustererNNShadow.outputDataReg2_32 );
1012+ }
1013+ }
1014+ runKernel<GPUTPCNNClusterizerKernels, GPUTPCNNClusterizerKernels::publishClass2Regression>({GetGrid (iSize, lane), krnlRunRangeNone}, iSector, clustererNNShadow.nnInferenceOutputDType , withMC, batchStart); // Running the NN for regression class 2
9791015 }
9801016 }
9811017 auto stop1 = std::chrono::high_resolution_clock::now ();
0 commit comments