11use colored:: * ;
2+ use std:: io:: { self , Write } ;
23
34pub fn info ( msg : & str ) {
45 println ! ( "{} {}" , "[xtask]" . green( ) . bold( ) , msg) ;
@@ -202,6 +203,7 @@ pub fn install_env(env: &str) {
202203 error (
203204 "nvidia-smi not found. Please make sure you have an NVIDIA GPU and drivers installed." ,
204205 ) ;
206+ return ;
205207 }
206208
207209 println ! ( ) ;
@@ -210,9 +212,20 @@ pub fn install_env(env: &str) {
210212 ) ;
211213 }
212214 "OpenCL" => {
213- info (
215+ warning (
214216 "The current automatic installation script only supports OpenCL installation for Intel CPU on Windows or Ubuntu systems." ,
215217 ) ;
218+ warning ( "type 'y' to continue, or any other key to exit." ) ;
219+ let mut input = String :: new ( ) ;
220+ io:: stdout ( ) . flush ( ) . unwrap ( ) ;
221+ io:: stdin ( )
222+ . read_line ( & mut input)
223+ . expect ( "Failed to read input" ) ;
224+ if input. trim ( ) . to_lowercase ( ) != "y" {
225+ info ( "Exiting OpenCL installation." ) ;
226+ return ;
227+ }
228+
216229 info ( "Checking if OpenCL is already installed..." ) ;
217230 println ! ( ) ;
218231
@@ -257,9 +270,11 @@ pub fn install_env(env: &str) {
257270 }
258271 } else {
259272 error ( "Failed to parse the number of platforms." ) ;
273+ return ;
260274 }
261275 } else {
262276 error ( "Failed to find 'Number of platforms' in clinfo output." ) ;
277+ return ;
263278 }
264279 }
265280 #[ cfg( not( target_os = "windows" ) ) ]
@@ -279,7 +294,7 @@ pub fn install_env(env: &str) {
279294 info ( "Installing clinfo tool..." ) ;
280295 let install_status = std:: process:: Command :: new ( "sh" )
281296 . arg ( "-c" )
282- . arg ( "sudo apt update && sudo apt install opencl-headers ocl-icd-opencl-dev -y" )
297+ . arg ( "sudo apt update && sudo apt install clinfo -y" )
283298 . status ( ) ;
284299
285300 if let Err ( e) = install_status {
@@ -313,9 +328,11 @@ pub fn install_env(env: &str) {
313328 }
314329 } else {
315330 error ( "Failed to parse the number of platforms." ) ;
331+ return ;
316332 }
317333 } else {
318334 error ( "Failed to find 'Number of platforms' in clinfo output." ) ;
335+ return ;
319336 }
320337 }
321338
0 commit comments