33
44from xrspatial import curvature
55from xrspatial .tests .general_checks import (assert_numpy_equals_cupy ,
6+ assert_numpy_equals_dask_cupy ,
67 assert_numpy_equals_dask_numpy , create_test_raster ,
78 cuda_and_cupy_available , general_output_checks )
89
@@ -87,9 +88,6 @@ def test_numpy_equals_cupy_random_data(random_data):
8788 numpy_agg = create_test_raster (random_data , backend = 'numpy' )
8889 cupy_agg = create_test_raster (random_data , backend = 'cupy' )
8990 assert_numpy_equals_cupy (numpy_agg , cupy_agg , curvature )
90- # NOTE: Dask + GPU code paths don't currently work because of
91- # dask casting cupy arrays to numpy arrays during
92- # https://github.com/dask/dask/issues/4842
9391
9492
9593@pytest .mark .parametrize ("size" , [(2 , 4 ), (10 , 15 )])
@@ -99,3 +97,13 @@ def test_numpy_equals_dask_random_data(random_data):
9997 numpy_agg = create_test_raster (random_data , backend = 'numpy' )
10098 dask_agg = create_test_raster (random_data , backend = 'dask' )
10199 assert_numpy_equals_dask_numpy (numpy_agg , dask_agg , curvature )
100+
101+
102+ @cuda_and_cupy_available
103+ @pytest .mark .parametrize ("size" , [(2 , 4 ), (10 , 15 )])
104+ @pytest .mark .parametrize (
105+ "dtype" , [np .int32 , np .int64 , np .uint32 , np .uint64 , np .float32 , np .float64 ])
106+ def test_numpy_equals_dask_cupy_random_data (random_data ):
107+ numpy_agg = create_test_raster (random_data , backend = 'numpy' )
108+ dask_cupy_agg = create_test_raster (random_data , backend = 'dask+cupy' )
109+ assert_numpy_equals_dask_cupy (numpy_agg , dask_cupy_agg , curvature , atol = 1e-6 , rtol = 1e-6 )
0 commit comments