@@ -748,36 +748,74 @@ static const std::string gpu_pipeline =
748748 " func.func(convert-parallel-loops-to-gpu),"
749749 // insert-gpu-allocs pass can have client-api = opencl or vulkan args
750750 " func.func(insert-gpu-allocs{in-regions=1}),"
751- // ** imex GPU passes
752- // "drop-regions,"
753- // "canonicalize,"
754- // // "normalize-memrefs,"
755- // // "gpu-decompose-memrefs,"
756- // "func.func(lower-affine),"
757- // "gpu-kernel-outlining,"
758- // "canonicalize,"
759- // "cse,"
760- // // The following set-spirv-* passes can have client-api = opencl or
761- // vulkan
762- // // args
763- // "set-spirv-capabilities{client-api=opencl},"
764- // "gpu.module(set-spirv-abi-attrs{client-api=opencl}),"
765- // "canonicalize,"
766- // "fold-memref-alias-ops,"
767- // "imex-convert-gpu-to-spirv{enable-vc-intrinsic=1},"
768- // "spirv.module(spirv-lower-abi-attrs),"
769- // "spirv.module(spirv-update-vce),"
770- // // "func.func(llvm-request-c-wrappers),"
771- // "serialize-spirv,"
772- // "expand-strided-metadata,"
773- // "lower-affine,"
774- // "convert-gpu-to-gpux,"
775- // "convert-func-to-llvm,"
776- // "convert-math-to-llvm,"
777- // "convert-gpux-to-llvm,"
778- // "finalize-memref-to-llvm,"
779- // "reconcile-unrealized-casts";
780- // ** nv GPU passes
751+ " drop-regions,"
752+ " canonicalize,"
753+ // "normalize-memrefs,"
754+ // "gpu-decompose-memrefs,"
755+ " func.func(lower-affine),"
756+ " gpu-kernel-outlining,"
757+ " canonicalize,"
758+ " cse,"
759+ // The following set-spirv-* passes can have client-api = opencl or vulkan
760+ // args
761+ " set-spirv-capabilities{client-api=opencl},"
762+ " gpu.module(set-spirv-abi-attrs{client-api=opencl}),"
763+ " canonicalize,"
764+ " fold-memref-alias-ops,"
765+ " imex-convert-gpu-to-spirv{enable-vc-intrinsic=1},"
766+ " spirv.module(spirv-lower-abi-attrs),"
767+ " spirv.module(spirv-update-vce),"
768+ // "func.func(llvm-request-c-wrappers),"
769+ " serialize-spirv,"
770+ " expand-strided-metadata,"
771+ " lower-affine,"
772+ " convert-gpu-to-gpux,"
773+ " convert-func-to-llvm,"
774+ " convert-math-to-llvm,"
775+ " convert-gpux-to-llvm,"
776+ " finalize-memref-to-llvm,"
777+ " reconcile-unrealized-casts" ;
778+
779+ static const std::string cuda_pipeline =
780+ " add-gpu-regions,"
781+ " canonicalize,"
782+ " ndarray-dist,"
783+ " func.func(dist-coalesce),"
784+ " func.func(dist-infer-elementwise-cores),"
785+ " convert-dist-to-standard,"
786+ " canonicalize,"
787+ " overlap-comm-and-compute,"
788+ " add-comm-cache-keys,"
789+ " lower-distruntime-to-idtr,"
790+ " convert-ndarray-to-linalg,"
791+ " canonicalize,"
792+ " func.func(tosa-make-broadcastable),"
793+ " func.func(tosa-to-linalg),"
794+ " func.func(tosa-to-tensor),"
795+ " canonicalize,"
796+ " linalg-fuse-elementwise-ops,"
797+ " arith-expand,"
798+ " memref-expand,"
799+ " arith-bufferize,"
800+ " func-bufferize,"
801+ " func.func(empty-tensor-to-alloc-tensor),"
802+ " func.func(scf-bufferize),"
803+ " func.func(tensor-bufferize),"
804+ " func.func(bufferization-bufferize),"
805+ " func.func(linalg-bufferize),"
806+ " func.func(linalg-detensorize),"
807+ " func.func(tensor-bufferize),"
808+ " region-bufferize,"
809+ " canonicalize,"
810+ " func.func(finalizing-bufferize),"
811+ " imex-remove-temporaries,"
812+ " func.func(convert-linalg-to-parallel-loops),"
813+ " func.func(scf-parallel-loop-fusion),"
814+ // is add-outer-parallel-loop needed?
815+ " func.func(imex-add-outer-parallel-loop),"
816+ " func.func(gpu-map-parallel-loops),"
817+ " func.func(convert-parallel-loops-to-gpu),"
818+ " func.func(insert-gpu-allocs{in-regions=1}),"
781819 " func.func(insert-gpu-copy),"
782820 " drop-regions,"
783821 " canonicalize,"
@@ -799,7 +837,9 @@ static const std::string gpu_pipeline =
799837
800838const std::string _passes (get_text_env (" SHARPY_PASSES" ));
801839static const std::string &pass_pipeline =
802- _passes != " " ? _passes : (useGPU () ? gpu_pipeline : cpu_pipeline);
840+ _passes != " " ? _passes
841+ : (useGPU () ? (useCUDA () ? cuda_pipeline : gpu_pipeline)
842+ : cpu_pipeline);
803843
804844JIT::JIT (const std::string &libidtr)
805845 : _context (::mlir::MLIRContext::Threading::DISABLED), _pm (&_context),
@@ -851,23 +891,24 @@ JIT::JIT(const std::string &libidtr)
851891 _crunnerlib = mlirRoot + " /lib/libmlir_c_runner_utils.so" ;
852892 _runnerlib = mlirRoot + " /lib/libmlir_runner_utils.so" ;
853893 if (!std::ifstream (_crunnerlib)) {
854- throw std::runtime_error (" Cannot find libmlir_c_runner_utils.so " );
894+ throw std::runtime_error (" Cannot find lib: " + _crunnerlib );
855895 }
856896 if (!std::ifstream (_runnerlib)) {
857- throw std::runtime_error (" Cannot find libmlir_runner_utils.so " );
897+ throw std::runtime_error (" Cannot find lib: " + _runnerlib );
858898 }
859899
860900 if (useGPU ()) {
861901 auto gpuxlibstr = get_text_env (" SHARPY_GPUX_SO" );
862902 if (!gpuxlibstr.empty ()) {
863903 _gpulib = std::string (gpuxlibstr);
864904 } else {
865- // auto imexRoot = get_text_env("IMEXROOT");
866- // imexRoot = !imexRoot.empty() ? imexRoot : std::string(CMAKE_IMEX_ROOT);
867- // _gpulib = imexRoot + "/lib/liblevel-zero-runtime.so";
868- // _gpulib = imexRoot + "/lib/liblevel-zero-runtime.so";
869- // for nv gpu
870- _gpulib = mlirRoot + " /lib/libmlir_cuda_runtime.so" ;
905+ if (useCUDA ()) {
906+ _gpulib = mlirRoot + " /lib/libmlir_cuda_runtime.so" ;
907+ } else {
908+ auto imexRoot = get_text_env (" IMEXROOT" );
909+ imexRoot = !imexRoot.empty () ? imexRoot : std::string (CMAKE_IMEX_ROOT);
910+ _gpulib = imexRoot + " /lib/liblevel-zero-runtime.so" ;
911+ }
871912 if (!std::ifstream (_gpulib)) {
872913 throw std::runtime_error (" Cannot find lib: " + _gpulib);
873914 }
0 commit comments