@@ -1932,23 +1932,66 @@ def test_svd(shape, full_matrices, compute_uv, device):
19321932 assert_sycl_queue_equal (dpnp_s_queue , expected_queue )
19331933
19341934
1935- @pytest .mark .parametrize (
1936- "device_from" ,
1937- valid_devices ,
1938- ids = [device .filter_string for device in valid_devices ],
1939- )
1940- @pytest .mark .parametrize (
1941- "device_to" ,
1942- valid_devices ,
1943- ids = [device .filter_string for device in valid_devices ],
1944- )
1945- def test_to_device (device_from , device_to ):
1946- data = [1.0 , 1.0 , 1.0 , 1.0 , 1.0 ]
1947-
1948- x = dpnp .array (data , dtype = dpnp .float32 , device = device_from )
1949- y = x .to_device (device_to )
1935+ class TestToDevice :
1936+ @pytest .mark .parametrize (
1937+ "device_from" ,
1938+ valid_devices ,
1939+ ids = [device .filter_string for device in valid_devices ],
1940+ )
1941+ @pytest .mark .parametrize (
1942+ "device_to" ,
1943+ valid_devices ,
1944+ ids = [device .filter_string for device in valid_devices ],
1945+ )
1946+ def test_basic (self , device_from , device_to ):
1947+ data = [1.0 , 1.0 , 1.0 , 1.0 , 1.0 ]
1948+ x = dpnp .array (data , dtype = dpnp .float32 , device = device_from )
1949+
1950+ y = x .to_device (device_to )
1951+ assert y .sycl_device == device_to
1952+ assert (x .asnumpy () == y .asnumpy ()).all ()
1953+
1954+ def test_to_queue (self ):
1955+ x = dpnp .full (100 , 2 , dtype = dpnp .int64 )
1956+ q_prof = dpctl .SyclQueue (x .sycl_device , property = "enable_profiling" )
1957+
1958+ y = x .to_device (q_prof )
1959+ assert (x .asnumpy () == y .asnumpy ()).all ()
1960+ assert_sycl_queue_equal (y .sycl_queue , q_prof )
1961+
1962+ def test_stream (self ):
1963+ x = dpnp .full (100 , 2 , dtype = dpnp .int64 )
1964+ q_prof = dpctl .SyclQueue (x .sycl_device , property = "enable_profiling" )
1965+ q_exec = dpctl .SyclQueue (x .sycl_device )
1966+
1967+ y = x .to_device (q_prof , stream = q_exec )
1968+ assert (x .asnumpy () == y .asnumpy ()).all ()
1969+ assert_sycl_queue_equal (y .sycl_queue , q_prof )
1970+
1971+ q_exec = dpctl .SyclQueue (x .sycl_device )
1972+ _ = dpnp .linspace (0 , 20 , num = 10 ** 5 , sycl_queue = q_exec )
1973+ y = x .to_device (q_prof , stream = q_exec )
1974+ assert (x .asnumpy () == y .asnumpy ()).all ()
1975+ assert_sycl_queue_equal (y .sycl_queue , q_prof )
1976+
1977+ def test_stream_no_sync (self ):
1978+ x = dpnp .full (100 , 2 , dtype = dpnp .int64 )
1979+ q_prof = dpctl .SyclQueue (x .sycl_device , property = "enable_profiling" )
1980+
1981+ for stream in [None , x .sycl_queue ]:
1982+ y = x .to_device (q_prof , stream = stream )
1983+ assert (x .asnumpy () == y .asnumpy ()).all ()
1984+ assert_sycl_queue_equal (y .sycl_queue , q_prof )
19501985
1951- assert y .sycl_device == device_to
1986+ @pytest .mark .parametrize (
1987+ "stream" ,
1988+ [1 , dict (), dpctl .SyclDevice ()],
1989+ ids = ["scalar" , "dictionary" , "device" ],
1990+ )
1991+ def test_invalid_stream (self , stream ):
1992+ x = dpnp .ones (2 , dtype = dpnp .int64 )
1993+ q_prof = dpctl .SyclQueue (x .sycl_device , property = "enable_profiling" )
1994+ assert_raises (TypeError , x .to_device , q_prof , stream = stream )
19521995
19531996
19541997@pytest .mark .parametrize (
0 commit comments