@@ -384,71 +384,30 @@ int main(int argc, char* argv[]) {
384384 torch::executor::MemoryManager memory_manager (
385385 &method_allocator, &planned_memory, &tmp_allocator);
386386
387- Result<torch::executor::Method> method =
388- program->load_method (method_name, &memory_manager);
389- if (!method.ok ()) {
390- fprintf (
391- stderr,
392- " Loading of method (%s) failed with status %" PRIu32 " ...\n " ,
393- method_name,
394- (unsigned int )method.error ());
395- exit (-1 );
396- }
397- printf (" Method loaded...\n " );
398-
399- Error status = Error::Ok;
400- if (!FLAGS_dataset.empty ()) {
401- // Go through entire dataset for this model.
402- FLAGS_dataset += " /" ;
403- while (dataset = readdir (datasetDir)) {
404- if (!strcmp (dataset->d_name , " ." ) || !strcmp (dataset->d_name , " .." ))
405- continue ;
406-
407- std::vector<std::string> inputsData;
408- inputsData.push_back (FLAGS_dataset + dataset->d_name );
409- // Set input and call inferrence.
410- setInputs (method.get (), inputsData);
411-
412- status = method->execute ();
413- if (status != Error::Ok) {
414- fprintf (
415- stderr,
416- " Execution of method %s failed with status %" PRIu32 " ...\n " ,
417- method_name,
418- (unsigned int )status);
419- exit (-1 );
420- } else {
421- printf (" Method executed successfully...\n " );
422- }
423-
424- // Save outputs in binary files.
425- saveOutputs (method.get (), FLAGS_output, dataset->d_name );
426- // Print result with highest confidence.
427- printOutput (method.get (), FLAGS_output, dataset->d_name );
387+ {
388+ Result<torch::executor::Method> method =
389+ program->load_method (method_name, &memory_manager);
390+ if (!method.ok ()) {
391+ fprintf (
392+ stderr,
393+ " Loading of method (%s) failed with status %" PRIu32 " ...\n " ,
394+ method_name,
395+ (unsigned int )method.error ());
396+ exit (-1 );
428397 }
429- closedir (datasetDir);
430- } else if (!FLAGS_inputs.empty ()) {
431- std::vector<std::string> inputPaths;
432-
433- // Validate and process inputs and separate into two lists.
434- processInputs (inputPaths, FLAGS_inputs);
435-
436- if (std::all_of (inputPaths.begin (), inputPaths.end (), isDirectory)) {
437- // Inputs are in directories - use files in each directory as the inputs.
438- std::vector<std::string> inputsData;
439- for (std::string& inputDir : inputPaths) {
440- datasetDir = opendir (inputDir.c_str ());
441- while (dataset = readdir (datasetDir)) {
442- if (!strcmp (dataset->d_name , " ." ) || !strcmp (dataset->d_name , " .." ))
443- continue ;
444-
445- inputsData.push_back (inputDir + " /" + dataset->d_name );
446- }
447- closedir (datasetDir);
448-
449- // Sort inputsData to ensure correct input ordering
450- std::sort (inputsData.begin (), inputsData.end ());
451-
398+ printf (" Method loaded...\n " );
399+
400+ Error status = Error::Ok;
401+ if (!FLAGS_dataset.empty ()) {
402+ // Go through entire dataset for this model.
403+ FLAGS_dataset += " /" ;
404+ while (dataset = readdir (datasetDir)) {
405+ if (!strcmp (dataset->d_name , " ." ) || !strcmp (dataset->d_name , " .." ))
406+ continue ;
407+
408+ std::vector<std::string> inputsData;
409+ inputsData.push_back (FLAGS_dataset + dataset->d_name );
410+ // Set input and call inferrence.
452411 setInputs (method.get (), inputsData);
453412
454413 status = method->execute ();
@@ -463,37 +422,81 @@ int main(int argc, char* argv[]) {
463422 printf (" Method executed successfully...\n " );
464423 }
465424
466- if (inputDir.back () == ' /' )
467- inputDir.pop_back ();
468-
469- auto pos = inputDir.find_last_of (' /' );
470- if (pos != std::string::npos)
471- inputDir = inputDir.substr (pos + 1 );
472-
473425 // Save outputs in binary files.
474- saveOutputs (method.get (), FLAGS_output, inputDir.c_str ());
475- inputsData.clear ();
426+ saveOutputs (method.get (), FLAGS_output, dataset->d_name );
427+ // Print result with highest confidence.
428+ printOutput (method.get (), FLAGS_output, dataset->d_name );
476429 }
477- } else {
478- // Inputs are files.
479- setInputs (method.get (), inputPaths);
480-
481- status = method->execute ();
482- if (status != Error::Ok) {
483- fprintf (
484- stderr,
485- " Execution of method %s failed with status %" PRIu32 " ...\n " ,
486- method_name,
487- (unsigned int )status);
488- exit (-1 );
430+ closedir (datasetDir);
431+ } else if (!FLAGS_inputs.empty ()) {
432+ std::vector<std::string> inputPaths;
433+
434+ // Validate and process inputs and separate into two lists.
435+ processInputs (inputPaths, FLAGS_inputs);
436+
437+ if (std::all_of (inputPaths.begin (), inputPaths.end (), isDirectory)) {
438+ // Inputs are in directories - use files in each directory as the
439+ // inputs.
440+ std::vector<std::string> inputsData;
441+ for (std::string& inputDir : inputPaths) {
442+ datasetDir = opendir (inputDir.c_str ());
443+ while (dataset = readdir (datasetDir)) {
444+ if (!strcmp (dataset->d_name , " ." ) || !strcmp (dataset->d_name , " .." ))
445+ continue ;
446+
447+ inputsData.push_back (inputDir + " /" + dataset->d_name );
448+ }
449+ closedir (datasetDir);
450+
451+ // Sort inputsData to ensure correct input ordering
452+ std::sort (inputsData.begin (), inputsData.end ());
453+
454+ setInputs (method.get (), inputsData);
455+
456+ status = method->execute ();
457+ if (status != Error::Ok) {
458+ fprintf (
459+ stderr,
460+ " Execution of method %s failed with status %" PRIu32 " ...\n " ,
461+ method_name,
462+ (unsigned int )status);
463+ exit (-1 );
464+ } else {
465+ printf (" Method executed successfully...\n " );
466+ }
467+
468+ if (inputDir.back () == ' /' )
469+ inputDir.pop_back ();
470+
471+ auto pos = inputDir.find_last_of (' /' );
472+ if (pos != std::string::npos)
473+ inputDir = inputDir.substr (pos + 1 );
474+
475+ // Save outputs in binary files.
476+ saveOutputs (method.get (), FLAGS_output, inputDir.c_str ());
477+ inputsData.clear ();
478+ }
489479 } else {
490- printf (" Method executed successfully...\n " );
491- }
480+ // Inputs are files.
481+ setInputs (method.get (), inputPaths);
482+
483+ status = method->execute ();
484+ if (status != Error::Ok) {
485+ fprintf (
486+ stderr,
487+ " Execution of method %s failed with status %" PRIu32 " ...\n " ,
488+ method_name,
489+ (unsigned int )status);
490+ exit (-1 );
491+ } else {
492+ printf (" Method executed successfully...\n " );
493+ }
492494
493- // Save outputs in binary files.
494- saveOutputs (method.get (), FLAGS_output);
495+ // Save outputs in binary files.
496+ saveOutputs (method.get (), FLAGS_output);
497+ }
495498 }
496- }
499+ } // Destruct the method object before destroying the Neutron Device.
497500
498501 printf (" Finished...\n " );
499502
0 commit comments