Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,19 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Fixed

- Tracer leak in OnGrid/FourierSeries laplacian_with_pml when using helmholtz_solver with checkpoint=False
- FiniteDifferences with non-default accuracy no longer causes pytree mismatch in time-domain simulation (#224)

### Changed

- Migrated from Poetry to uv for dependency management and builds
- Minimum Python version bumped to 3.11
- Upgraded plumkdocs to >=1.0.0 and mkdocstrings to >=1.0.0
- Upgraded jaxdf dependency to >=0.3.0

## [0.2.1] - 2024-09-17

### Changed

- Upgraded `jaxdf` dependency

## [0.2.0] - 2023-12-18
Expand Down
7 changes: 4 additions & 3 deletions jwave/acoustics/time_harmonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Union

import jax
import numpy as np
from jax import numpy as jnp
from jax.lax import while_loop
from jax.scipy.sparse.linalg import bicgstab, gmres
Expand Down Expand Up @@ -164,12 +165,12 @@ def _cbs_norm_units(medium, omega, k0, src):
# Store conversion variables
domain = medium.domain
_conversion = {
"dx": jnp.mean(jnp.asarray(domain.dx)),
"dx": float(np.mean(domain.dx)),
"omega": omega,
}

# Set discretization to 1
dx = tuple(map(lambda x: x / _conversion["dx"], domain.dx))
dx = tuple(float(x / _conversion["dx"]) for x in domain.dx)
domain = Domain(domain.N, dx)

# set omega to 1
Expand Down Expand Up @@ -197,7 +198,7 @@ def _cbs_norm_units(medium, omega, k0, src):

def _cbs_unnorm_units(field, conversion):
domain = field.domain
dx = tuple(map(lambda x: x * conversion["dx"], domain.dx))
dx = tuple(float(x * conversion["dx"]) for x in domain.dx)
domain = Domain(domain.N, dx)

return FourierSeries(field.params, domain)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ license = "LGPL-3.0-only"
keywords = ["jax", "acoustics", "simulation", "ultrasound", "differentiable-programming"]
requires-python = ">=3.11"
dependencies = [
"jaxdf>=0.2.8",
"jaxdf>=0.3.0",
"matplotlib>=3.0.0",
]
classifiers = [
Expand Down
23 changes: 21 additions & 2 deletions tests/acoustics/test_simulate_wave_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

from jax import numpy as jnp

from jwave.acoustics import simulate_wave_propagation
from jwave.geometry import Domain, FourierSeries, Medium, TimeAxis
from jwave import FiniteDifferences
from jwave.acoustics import TimeWavePropagationSettings, simulate_wave_propagation
from jwave.geometry import Domain, FourierSeries, Medium, TimeAxis, circ_mask
from jwave.logger import logger, set_logging_level


Expand Down Expand Up @@ -40,5 +41,23 @@ def test_correct_call():
assert "Starting simulation using FourierSeries code" in log_contents


def test_fd_nondefault_accuracy():
"""Regression test for jwave#224: FD fields with accuracy != 8
must not cause pytree mismatch in lax.scan."""
domain = Domain((64, 64), (1e-3, 1e-3))
p0_arr = 5.0 * circ_mask(domain.N, 3, (32, 32))
p0 = FiniteDifferences(
jnp.expand_dims(p0_arr, -1), domain, accuracy=4)
sound_speed = FiniteDifferences(
jnp.expand_dims(jnp.ones(domain.N) * 1500.0, -1), domain, accuracy=4)
medium = Medium(domain, sound_speed=sound_speed, pml_size=0)
time_axis = TimeAxis.from_medium(medium, cfl=0.1)
time_axis.t_end = 2e-6
settings = TimeWavePropagationSettings(smooth_initial=False)

p = simulate_wave_propagation(medium, time_axis, p0=p0, settings=settings)
assert p is not None


if __name__ == "__main__":
test_correct_call()
4 changes: 3 additions & 1 deletion tests/test_off_grid_sensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,11 @@ def test_sensor(nx, ny, nz):
s3d = BLISensors((x + 0.25, y + 0.3, z + 0.1), (nx, ny, nz))
domain3d = Domain((nx, ny, nz), (1, 1, 1))
# Check ones in ones out.
# rtol=1e-4: BLI uses float32 3D FFT interpolation, which has limited
# precision that varies across platforms (different BLAS/FFT backends).
p3d = FourierSeries(np.ones((nx, ny, nz)), domain3d)
y = s3d(p3d, None, None)
assert (np.all(np.isclose(y, 1)))
assert (np.all(np.isclose(y, 1, rtol=1e-4)))

# Check zeros in zeros out
p3d = FourierSeries(np.zeros((nx, ny, nz)), domain3d)
Expand Down
59 changes: 33 additions & 26 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading