Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
a525cc8
fix: use dynamic slice for set and get ops for images
beckermr Mar 24, 2026
9976783
fix: ensure we declare things as jax
beckermr Mar 24, 2026
f557ce3
Merge branch 'main' into jax-api-bounds-again
beckermr Mar 25, 2026
c2f6120
feat: make int bounds have fixed width
beckermr Mar 26, 2026
84f1703
Merge branch 'main' into jax-api-bounds-again
beckermr Mar 26, 2026
46b8200
fix: ensure tests pass for new bounds
beckermr Mar 26, 2026
2849295
fix: get more bugs for new bounds
beckermr Mar 26, 2026
88aeb48
fix: ensure min coords are not static; use deltas for bounds; avoid i…
beckermr Mar 26, 2026
5b695ad
fix: ensure we can use JIT and grad with new bounds
beckermr Mar 26, 2026
c9ae47b
test: ignore methods in api tests
beckermr Mar 26, 2026
4519568
fix: more changes for xmin,ymin being traced
beckermr Mar 26, 2026
f07022b
fix: this is plus one
beckermr Mar 26, 2026
a8b9b9a
feat: mark bounds as static and adjust pytrees accordingly
beckermr Mar 26, 2026
10bef6f
test: add basic tests of vectorized drawing
beckermr Mar 26, 2026
9e72549
style: pre-commit
beckermr Mar 26, 2026
c52b766
test: make tests more robust
beckermr Mar 26, 2026
19399f5
test: add test of rendering a full scene
beckermr Mar 26, 2026
f4899bd
style: pre the commit
beckermr Mar 26, 2026
7fff7a5
test: add more scene drawing tests
beckermr Mar 26, 2026
b92be8a
doc: add doc strings
beckermr Mar 26, 2026
a60fd2f
Update bounds.py
beckermr Mar 27, 2026
15b6fee
Update image.py
beckermr Mar 27, 2026
3db09cd
Apply suggestion from @beckermr
beckermr Mar 27, 2026
0bfc8a8
fix: make strings not crazy
beckermr Mar 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
482 changes: 413 additions & 69 deletions jax_galsim/bounds.py

Large diffs are not rendered by default.

16 changes: 6 additions & 10 deletions jax_galsim/core/wrap_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _block_reduce_loop(sim, nx, ny, nxwrap, nywrap):
return fim


@partial(jax.jit, static_argnames=("xmin", "ymin", "nxwrap", "nywrap"))
@partial(jax.jit, static_argnames=("nxwrap", "nywrap"))
def wrap_nonhermitian(im, xmin, ymin, nxwrap, nywrap):
# these bits compute how many total blocks we need to cover the image
nx = im.shape[1] // nxwrap
Expand All @@ -81,7 +81,11 @@ def wrap_nonhermitian(im, xmin, ymin, nxwrap, nywrap):
else:
fim = _block_reduce_loop(sim, nx, ny, nxwrap, nywrap)

im = im.at[ymin : ymin + nywrap, xmin : xmin + nxwrap].set(fim)
im = jax.lax.dynamic_update_slice(
im,
fim,
(ymin, xmin),
)
return im


Expand All @@ -98,10 +102,6 @@ def contract_hermitian_x(im):
@partial(
jax.jit,
static_argnames=[
"im_xmin",
"im_ymin",
"wrap_xmin",
"wrap_ymin",
"wrap_nx",
"wrap_ny",
],
Expand All @@ -127,10 +127,6 @@ def contract_hermitian_y(im):
@partial(
jax.jit,
static_argnames=[
"im_xmin",
"im_ymin",
"wrap_xmin",
"wrap_ymin",
"wrap_nx",
"wrap_ny",
],
Expand Down
31 changes: 24 additions & 7 deletions jax_galsim/gsobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def _setup_image(
N = self.getGoodImageSize(1.0)
if odd:
N += 1
bounds = BoundsI(1, N, 1, N)
bounds = BoundsI(xmin=1, deltax=N, ymin=1, deltay=N)
image.resize(bounds)
# Else use the given image as is

Expand Down Expand Up @@ -486,7 +486,7 @@ def _get_new_bounds(self, image, nx, ny, bounds, center):
if image is not None and image.bounds.isDefined():
return image.bounds
elif nx is not None and ny is not None:
b = BoundsI(1, nx, 1, ny)
b = BoundsI(xmin=1, deltax=nx, ymin=1, deltay=ny)
if center is not None:
# this code has to match the code in _setup_image
# for the same branch of the if statement block
Expand Down Expand Up @@ -853,7 +853,14 @@ def drawFFT_makeKImage(self, image):
image_N = jnp.max(
jnp.array(
[
jnp.max(jnp.abs(jnp.array(image.bounds._getinitargs()))) * 2,
jnp.max(
jnp.abs(
jnp.array(
[image.xmin, image.xmax, image.ymin, image.ymax]
)
)
)
* 2,
jnp.max(jnp.array(image.bounds.numpyShape())),
]
)
Expand All @@ -880,7 +887,9 @@ def drawFFT_makeKImage(self, image):
"drawFFT requires an FFT that is too large.", Nk
)

bounds = BoundsI(0, Nk // 2, -Nk // 2, Nk // 2)
bounds = BoundsI(
xmin=0, deltax=Nk // 2 + 1, ymin=-Nk // 2, deltay=2 * (Nk // 2) + 1
)
if image.dtype in (np.complex128, np.float64, np.int32, np.uint32):
kimage = ImageCD(bounds=bounds, scale=dk)
else:
Expand All @@ -895,12 +904,20 @@ def drawFFT_finish(self, image, kimage, wrap_size, add_to_image):
# Wrap the full image to the size we want for the FT.
# Even if N == Nk, this is useful to make this portion properly Hermitian in the
# N/2 column and N/2 row.
bwrap = BoundsI(0, wrap_size // 2, -wrap_size // 2, wrap_size // 2 - 1)
kimage_wrap = kimage._wrap(bwrap, True, False)
bwrap = BoundsI(
xmin=0,
deltax=wrap_size // 2 + 1,
ymin=-wrap_size // 2,
deltay=2 * (wrap_size // 2),
)
kimage_wrap = kimage._wrap(bwrap, True, False, wrap_size)

# Perform the fourier transform.
breal = BoundsI(
-wrap_size // 2, wrap_size // 2 - 1, -wrap_size // 2, wrap_size // 2 - 1
xmin=-wrap_size // 2,
deltax=2 * (wrap_size // 2),
ymin=-wrap_size // 2,
deltay=2 * (wrap_size // 2),
)
kimg_shift = jnp.fft.ifftshift(kimage_wrap.array, axes=(-2,))
real_image_arr = jnp.fft.fftshift(
Expand Down
Loading
Loading