Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion tests/test_config_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,7 +1367,11 @@ def test_tiled():
gal = galsim.Gaussian(sigma=sigma, flux=flux)
gal.drawImage(stamp)
stamp.addNoise(galsim.GaussianNoise(sigma=0.5, rng=ud))
im1a[stamp.bounds] = stamp
if is_jax_galsim():
# jax-galsim uses the JAX .at API for inplace ops
im1a = im1a.at[stamp.bounds].set(stamp)
else:
im1a[stamp.bounds] = stamp

# Compare to what config builds
im1b = galsim.config.BuildImage(config)
Expand Down
223 changes: 178 additions & 45 deletions tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,16 @@ def test_Image_basic():
value2 = 53 + 12*x - 19*y
if tchar[i] in ['US', 'UI']:
value2 = abs(value2)
im1[x,y] = value2
im2_view[galsim.PositionI(x,y)] = value2
if is_jax_galsim():
# jax-galsim uses the JAX .at API for inplace ops
im1 = im1.at[x,y].set(value2)
else:
im1[x,y] = value2
if is_jax_galsim():
# jax-galsim uses the JAX .at API for inplace ops
im2_view = im2_view.at[galsim.PositionI(x,y)].set(value2)
else:
im2_view[galsim.PositionI(x,y)] = value2
assert im1.getValue(x,y) == value2
assert im1.view().getValue(x=x, y=y) == value2
assert im1.view(make_const=True).getValue(x,y) == value2
Expand Down Expand Up @@ -278,7 +286,11 @@ def test_Image_basic():
else:
value3 = 10*x + y
im1.addValue(x,y, np.int64(value3-value2))
im2_view[x,y] += np.int64(value3-value2)
if is_jax_galsim():
# jax-galsim uses the JAX .at API for inplace ops
im2_view = im2_view.at[x,y].add(np.int64(value3-value2))
else:
im2_view[x,y] += np.int64(value3-value2)
assert im1[galsim.PositionI(x,y)] == value3
assert im1.view()[x,y] == value3
assert im1.view(make_const=True)[galsim.PositionI(x,y)] == value3
Expand All @@ -299,11 +311,19 @@ def test_Image_basic():
assert_raises(galsim.GalSimBoundsError,im1.addValue,0,0,1)
assert_raises(galsim.GalSimBoundsError,im1.__call__,0,0)
assert_raises(galsim.GalSimBoundsError,im1.__getitem__,0,0)
assert_raises(galsim.GalSimBoundsError,im1.__setitem__,0,0,1)
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
assert_raises(galsim.GalSimBoundsError,lambda x, y, v: im1.at[x, y].set(v),0,0,1)
else:
assert_raises(galsim.GalSimBoundsError,im1.__setitem__,0,0,1)
assert_raises(galsim.GalSimBoundsError,im1.view().setValue,0,0,1)
assert_raises(galsim.GalSimBoundsError,im1.view().__call__,0,0)
assert_raises(galsim.GalSimBoundsError,im1.view().__getitem__,0,0)
assert_raises(galsim.GalSimBoundsError,im1.view().__setitem__,0,0,1)
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
assert_raises(galsim.GalSimBoundsError,lambda x, y, v: im1.view().at[x, y].set(v),0,0,1)
else:
assert_raises(galsim.GalSimBoundsError,im1.view().__setitem__,0,0,1)

assert_raises(galsim.GalSimBoundsError,im1.setValue,ncol+1,0,1)
assert_raises(galsim.GalSimBoundsError,im1.addValue,ncol+1,0,1)
Expand Down Expand Up @@ -344,16 +364,29 @@ def test_Image_basic():
galsim.Image(ncol+1,nrow, init_value=10))
assert_raises(galsim.GalSimBoundsError,im1.setSubImage,galsim.BoundsI(0,ncol+1,0,nrow+1),
galsim.Image(ncol+2,nrow+2, init_value=10))
assert_raises(galsim.GalSimBoundsError,im1.__setitem__,galsim.BoundsI(0,ncol,1,nrow),
galsim.Image(ncol+1,nrow, init_value=10))
assert_raises(galsim.GalSimBoundsError,im1.__setitem__,galsim.BoundsI(1,ncol,0,nrow),
galsim.Image(ncol+1,nrow, init_value=10))
assert_raises(galsim.GalSimBoundsError,im1.__setitem__,galsim.BoundsI(1,ncol+1,1,nrow),
galsim.Image(ncol+1,nrow, init_value=10))
assert_raises(galsim.GalSimBoundsError,im1.__setitem__,galsim.BoundsI(1,ncol,1,nrow+1),
galsim.Image(ncol+1,nrow, init_value=10))
assert_raises(galsim.GalSimBoundsError,im1.__setitem__,galsim.BoundsI(0,ncol+1,0,nrow+1),
galsim.Image(ncol+2,nrow+2, init_value=10))
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
assert_raises(galsim.GalSimBoundsError,lambda b, v: im1.at[b].set(v),galsim.BoundsI(0,ncol,1,nrow),
galsim.Image(ncol+1,nrow, init_value=10))
assert_raises(galsim.GalSimBoundsError,lambda b, v: im1.at[b].set(v),galsim.BoundsI(1,ncol,0,nrow),
galsim.Image(ncol+1,nrow, init_value=10))
assert_raises(galsim.GalSimBoundsError,lambda b, v: im1.at[b].set(v),galsim.BoundsI(1,ncol+1,1,nrow),
galsim.Image(ncol+1,nrow, init_value=10))
assert_raises(galsim.GalSimBoundsError,lambda b, v: im1.at[b].set(v),galsim.BoundsI(1,ncol,1,nrow+1),
galsim.Image(ncol+1,nrow, init_value=10))
assert_raises(galsim.GalSimBoundsError,lambda b, v: im1.at[b].set(v),galsim.BoundsI(0,ncol+1,0,nrow+1),
galsim.Image(ncol+2,nrow+2, init_value=10))
else:
assert_raises(galsim.GalSimBoundsError,im1.__setitem__,galsim.BoundsI(0,ncol,1,nrow),
galsim.Image(ncol+1,nrow, init_value=10))
assert_raises(galsim.GalSimBoundsError,im1.__setitem__,galsim.BoundsI(1,ncol,0,nrow),
galsim.Image(ncol+1,nrow, init_value=10))
assert_raises(galsim.GalSimBoundsError,im1.__setitem__,galsim.BoundsI(1,ncol+1,1,nrow),
galsim.Image(ncol+1,nrow, init_value=10))
assert_raises(galsim.GalSimBoundsError,im1.__setitem__,galsim.BoundsI(1,ncol,1,nrow+1),
galsim.Image(ncol+1,nrow, init_value=10))
assert_raises(galsim.GalSimBoundsError,im1.__setitem__,galsim.BoundsI(0,ncol+1,0,nrow+1),
galsim.Image(ncol+2,nrow+2, init_value=10))

# Also, setting values in something that should be const
assert_raises(galsim.GalSimImmutableError,im1.view(make_const=True).setValue,1,1,1)
Expand All @@ -364,9 +397,17 @@ def test_Image_basic():

# Finally check for the wrong number of arguments in get/setitem
assert_raises(TypeError,im1.__getitem__,1)
assert_raises(TypeError,im1.__setitem__,1,1)
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
assert_raises(TypeError,lambda b, v: im1.at[b].set(v),1,1)
else:
assert_raises(TypeError,im1.__setitem__,1,1)
assert_raises(TypeError,im1.__getitem__,1,2,3)
assert_raises(TypeError,im1.__setitem__,1,2,3,4)
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
assert_raises(TypeError,lambda x, y, z, v: im1.at[x, y, z].set(v),1,2,3,4)
else:
assert_raises(TypeError,im1.__setitem__,1,2,3,4)

# Check view of given data
im3_view = galsim.Image(ref_array.astype(np_array_type))
Expand Down Expand Up @@ -519,8 +560,13 @@ def test_undefined_image():

assert_raises(galsim.GalSimUndefinedBoundsError,im1.setSubImage,galsim.BoundsI(1,2,1,2),
galsim.Image(2,2, init_value=10))
assert_raises(galsim.GalSimUndefinedBoundsError,im1.__setitem__,galsim.BoundsI(1,2,1,2),
galsim.Image(2,2, init_value=10))
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
assert_raises(galsim.GalSimUndefinedBoundsError,lambda b,v: im1.at[b].set(v),galsim.BoundsI(1,2,1,2),
galsim.Image(2,2, init_value=10))
else:
assert_raises(galsim.GalSimUndefinedBoundsError,im1.__setitem__,galsim.BoundsI(1,2,1,2),
galsim.Image(2,2, init_value=10))

im1.scale = 1.
assert_raises(galsim.GalSimUndefinedBoundsError,im1.calculate_fft)
Expand Down Expand Up @@ -2097,7 +2143,11 @@ def test_Image_subImage():
err_msg="image.subImage(bounds) does not match reference for dtype = "+str(types[i]))
np.testing.assert_array_equal(image[bounds].array, sub_array,
err_msg="image[bounds] does not match reference for dtype = "+str(types[i]))
image[bounds] = galsim.Image(sub_array+100)
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
image = image.at[bounds].set(galsim.Image(sub_array+100))
else:
image[bounds] = galsim.Image(sub_array+100)
np.testing.assert_array_equal(image[bounds].array, (sub_array+100),
err_msg="image[bounds] = im2 does not set correctly for dtype = "+str(types[i]))
for xpos in range(1,test_shape[0]+1):
Expand All @@ -2111,67 +2161,131 @@ def test_Image_subImage():
"image[bounds] = im2 set wrong locations for dtype = "+str(types[i])

image = galsim.Image(ref_array.astype(types[i]))
image[bounds] += 100
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
image = image.at[bounds].add(100)
else:
image[bounds] += 100
np.testing.assert_array_equal(image[bounds].array, (sub_array+100),
err_msg="image[bounds] += 100 does not set correctly for dtype = "+str(types[i]))
image[bounds] = galsim.Image(sub_array)
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
image = image.at[bounds].set(galsim.Image(sub_array))
else:
image[bounds] = galsim.Image(sub_array)
np.testing.assert_array_equal(image.array, ref_array,
err_msg="image[bounds] += 100 set wrong locations for dtype = "+str(types[i]))

image = galsim.Image(ref_array.astype(types[i]))
image[bounds] -= 100
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
image = image.at[bounds].subtract(100)
else:
image[bounds] -= 100
np.testing.assert_array_equal(image[bounds].array, (sub_array-100),
err_msg="image[bounds] -= 100 does not set correctly for dtype = "+str(types[i]))
image[bounds] = galsim.Image(sub_array)
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
image = image.at[bounds].set(galsim.Image(sub_array))
else:
image[bounds] = galsim.Image(sub_array)
np.testing.assert_array_equal(image.array, ref_array,
err_msg="image[bounds] -= 100 set wrong locations for dtype = "+str(types[i]))

image = galsim.Image(ref_array.astype(types[i]))
image[bounds] *= 100
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
image = image.at[bounds].multiply(100)
else:
image[bounds] *= 100
np.testing.assert_array_equal(image[bounds].array, (sub_array*100),
err_msg="image[bounds] *= 100 does not set correctly for dtype = "+str(types[i]))
image[bounds] = galsim.Image(sub_array)
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
image = image.at[bounds].set(galsim.Image(sub_array))
else:
image[bounds] = galsim.Image(sub_array)
np.testing.assert_array_equal(image.array, ref_array,
err_msg="image[bounds] *= 100 set wrong locations for dtype = "+str(types[i]))

image = galsim.Image((100*ref_array).astype(types[i]))
image[bounds] /= 100
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
image = image.at[bounds].divide(100)
else:
image[bounds] /= 100
np.testing.assert_array_equal(image[bounds].array, (sub_array),
err_msg="image[bounds] /= 100 does not set correctly for dtype = "+str(types[i]))
image[bounds] = galsim.Image((100*sub_array).astype(types[i]))
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
image = image.at[bounds].set(galsim.Image((100*sub_array).astype(types[i])))
else:
image[bounds] = galsim.Image((100*sub_array).astype(types[i]))
np.testing.assert_array_equal(image.array, (100*ref_array),
err_msg="image[bounds] /= 100 set wrong locations for dtype = "+str(types[i]))

im2 = galsim.Image(sub_array)
image = galsim.Image(ref_array.astype(types[i]))
image[bounds] += im2
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
image = image.at[bounds].add(im2)
else:
image[bounds] += im2
np.testing.assert_array_equal(image[bounds].array, (2*sub_array),
err_msg="image[bounds] += im2 does not set correctly for dtype = "+str(types[i]))
image[bounds] = galsim.Image(sub_array)
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
image = image.at[bounds].set(galsim.Image(sub_array))
else:
image[bounds] = galsim.Image(sub_array)
np.testing.assert_array_equal(image.array, ref_array,
err_msg="image[bounds] += im2 set wrong locations for dtype = "+str(types[i]))

image = galsim.Image(2*ref_array.astype(types[i]))
image[bounds] -= im2
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
image = image.at[bounds].subtract(im2)
else:
image[bounds] -= im2
np.testing.assert_array_equal(image[bounds].array, sub_array,
err_msg="image[bounds] -= im2 does not set correctly for dtype = "+str(types[i]))
image[bounds] = galsim.Image((2*sub_array).astype(types[i]))
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
image = image.at[bounds].set(galsim.Image((2*sub_array).astype(types[i])))
else:
image[bounds] = galsim.Image((2*sub_array).astype(types[i]))
np.testing.assert_array_equal(image.array, (2*ref_array),
err_msg="image[bounds] -= im2 set wrong locations for dtype = "+str(types[i]))

image = galsim.Image(ref_array.astype(types[i]))
image[bounds] *= im2
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
image = image.at[bounds].multiply(im2)
else:
image[bounds] *= im2
np.testing.assert_array_equal(image[bounds].array, (sub_array**2),
err_msg="image[bounds] *= im2 does not set correctly for dtype = "+str(types[i]))
image[bounds] = galsim.Image(sub_array)
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
image = image.at[bounds].set(galsim.Image(sub_array))
else:
image[bounds] = galsim.Image(sub_array)
np.testing.assert_array_equal(image.array, ref_array,
err_msg="image[bounds] *= im2 set wrong locations for dtype = "+str(types[i]))

image = galsim.Image((2 * ref_array**2).astype(types[i]))
image[bounds] /= im2
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
image = image.at[bounds].divide(im2)
else:
image[bounds] /= im2
np.testing.assert_array_equal(image[bounds].array, (2*sub_array),
err_msg="image[bounds] /= im2 does not set correctly for dtype = "+str(types[i]))
image[bounds] = galsim.Image((2*sub_array**2).astype(types[i]))
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
image = image.at[bounds].set(galsim.Image((2*sub_array**2).astype(types[i])))
else:
image[bounds] = galsim.Image((2*sub_array**2).astype(types[i]))
np.testing.assert_array_equal(image.array, (2*ref_array**2),
err_msg="image[bounds] /= im2 set wrong locations for dtype = "+str(types[i]))

Expand Down Expand Up @@ -2728,7 +2842,11 @@ def test_copy():
assert im10b.wcs == im.wcs
assert im10b.bounds == im.bounds
np.testing.assert_array_equal(im10b.array, im.array)
im10b[2,3] = 27
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
im10b = im10b.at[2,3].set(27)
else:
im10b[2,3] = 27
assert im10b(2,3) == 27.
assert im(2,3) != 27.

Expand All @@ -2738,7 +2856,11 @@ def test_copy():
assert im5.bounds == im8.bounds
np.testing.assert_array_equal(im5.array, im8.array)
assert im5(3,8) == 11.
im8[3,8] = 15
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
im8 = im8.at[3,8].set(15)
else:
im8[3,8] = 15
assert im5(3,8) == 11.

assert_raises(TypeError, im5.copyFrom, im8.array)
Expand Down Expand Up @@ -3429,7 +3551,11 @@ def test_wrap():
for i in range(17):
for j in range(23):
val = np.exp(i/7.3) + (j/12.9)**3 # Something randomly complicated...
im[i,j] = val
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
im = im.at[i,j].set(val)
else:
im[i,j] = val
# Find the location in the sub-image for this point.
ii = (i-b.xmin) % (b.xmax-b.xmin+1) + b.xmin
jj = (j-b.ymin) % (b.ymax-b.ymin+1) + b.ymin
Expand Down Expand Up @@ -3463,12 +3589,19 @@ def test_wrap():
# An arbitrary, complicated Hermitian function.
val = np.exp((i/(2.3*M))**2 + 1j*(2.8*i-1.3*j)) + ((2 + 3j*j)/(1.9*N))**3
#val = 2*(i-j)**2 + 3j*(i+j)

im[i,j] = val
if j >= 0:
im2[i,j] = val
if i >= 0:
im3[i,j] = val
if is_jax_galsim():
# jax-galsim uses .at syntax for setting items
im = im.at[i,j].set(val)
if j >= 0:
im2 = im2.at[i,j].set(val)
if i >= 0:
im3 = im3.at[i,j].set(val)
else:
im[i,j] = val
if j >= 0:
im2[i,j] = val
if i >= 0:
im3[i,j] = val

ii = (i-b.xmin) % (b.xmax-b.xmin+1) + b.xmin
jj = (j-b.ymin) % (b.ymax-b.ymin+1) + b.ymin
Expand Down