Skip to content
This repository was archived by the owner on Jan 26, 2026. It is now read-only.

Commit 730b0ab

Browse files
authored
Wave equation example (#8)
1 parent 59a1eee commit 730b0ab

File tree

3 files changed

+241
-5
lines changed

3 files changed

+241
-5
lines changed

examples/wave_equation.py

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
"""
2+
Linear wave equation benchmark
3+
4+
Usage:
5+
6+
Verify solution with 128x128 problem size
7+
8+
.. code-block::
9+
10+
python wave_equation.py
11+
12+
Run a performance test with 1024x1024 problem size.
13+
Runs a fixed number of steps with a small time step.
14+
15+
.. code-block::
16+
17+
python wave_equation.py -n 1024 -t
18+
19+
Run with numpy backend
20+
21+
.. code-block::
22+
23+
python wave_equation.py -b numpy ...
24+
25+
"""
26+
import math
27+
import numpy
28+
import time as time_mod
29+
import argparse
30+
31+
32+
def run(n, backend, benchmark_mode):
33+
if backend == "ddpt":
34+
import ddptensor as np
35+
from ddptensor.numpy import fromfunction
36+
from ddptensor import init, fini, sync
37+
all_axes = [0, 1]
38+
elif backend == "numpy":
39+
import numpy as np
40+
from numpy import fromfunction
41+
init = fini = sync = lambda x = None: None
42+
all_axes = None
43+
else:
44+
raise ValueError(f'Unknown backend: "{backend}"')
45+
46+
print(f'Using backend: {backend}')
47+
init(False)
48+
49+
# constants
50+
h = 1.0
51+
g = 9.81
52+
53+
# domain extent
54+
# NOTE need to be floats
55+
xmin = -1.0
56+
xmax = 1.0
57+
ymin = -1.0
58+
ymax = 1.0
59+
lx = xmax - xmin
60+
ly = ymax - ymin
61+
62+
# grid resolution
63+
nx = n
64+
ny = n
65+
# grid spacing
66+
dx = lx/nx
67+
dy = lx/ny
68+
69+
# export interval
70+
t_export = 0.02
71+
t_end = 1.0
72+
73+
# coordinate arrays
74+
x_t_2d = fromfunction(
75+
lambda i, j: xmin + i*dx + dx/2, (nx, ny), dtype=np.float64)
76+
y_t_2d = fromfunction(
77+
lambda i, j: ymin + j*dy + dy/2, (nx, ny), dtype=np.float64)
78+
79+
T_shape = (nx, ny)
80+
U_shape = (nx + 1, ny)
81+
V_shape = (nx, ny+1)
82+
83+
dofs_T = int(numpy.prod(numpy.asarray(T_shape)))
84+
dofs_U = int(numpy.prod(numpy.asarray(U_shape)))
85+
dofs_V = int(numpy.prod(numpy.asarray(V_shape)))
86+
87+
print(f'Grid size: {nx} x {ny}')
88+
print(f'Elevation DOFs: {dofs_T}')
89+
print(f'Velocity DOFs: {dofs_U + dofs_V}')
90+
print(f'Total DOFs: {dofs_T + dofs_U + dofs_V}')
91+
92+
# prognostic variables: elevation, (u, v) velocity
93+
e = np.full(T_shape, 0.0, np.float64)
94+
u = np.full(U_shape, 0.0, np.float64)
95+
v = np.full(V_shape, 0.0, np.float64)
96+
97+
# auxiliary variables for RK time integration
98+
e1 = np.full(T_shape, 0.0, np.float64)
99+
u1 = np.full(U_shape, 0.0, np.float64)
100+
v1 = np.full(V_shape, 0.0, np.float64)
101+
e2 = np.full(T_shape, 0.0, np.float64)
102+
u2 = np.full(U_shape, 0.0, np.float64)
103+
v2 = np.full(V_shape, 0.0, np.float64)
104+
105+
def exact_elev(t, x_t_2d, y_t_2d, lx, ly):
106+
"""
107+
Exact solution for elevation field.
108+
109+
Returns time-dependent elevation of a 2D standing wave in a rectangular
110+
domain.
111+
"""
112+
amp = 0.5
113+
c = (g * h) ** 0.5
114+
n = 1
115+
sol_x = np.cos(2 * n * math.pi * x_t_2d / lx)
116+
m = 1
117+
sol_y = np.cos(2 * m * math.pi * y_t_2d / ly)
118+
omega = c * math.pi * ((n/lx)**2 + (m/ly)**2)**0.5
119+
# NOTE ddpt fails with scalar computation
120+
sol_t = numpy.cos(2 * omega * t)
121+
return amp * sol_x * sol_y * sol_t
122+
123+
# inital elevation
124+
e[0:nx, 0:ny] = exact_elev(0.0, x_t_2d, y_t_2d, lx, ly)
125+
126+
# compute time step
127+
alpha = 0.5
128+
c = (g * h) ** 0.5
129+
dt = alpha * dx / c
130+
dt = t_export / int(math.ceil(t_export / dt))
131+
nt = int(math.ceil(t_end / dt))
132+
if benchmark_mode:
133+
dt = 1e-5
134+
nt = 100
135+
t_export = dt*25
136+
137+
print(f'Time step: {dt} s')
138+
print(f'Total run time: {t_end} s, {nt} time steps')
139+
140+
sync()
141+
142+
def rhs(u, v, e):
143+
"""
144+
Evaluate right hand side of the equations
145+
"""
146+
# sign convention: positive on rhs
147+
148+
# pressure gradient -g grad(elev)
149+
dudt = -g * (e[1:nx, 0:ny] - e[0:nx-1, 0:ny]) / dx
150+
dvdt = -g * (e[0:nx, 1:ny] - e[0:nx, 0:ny-1]) / dy
151+
152+
# velocity divergence -h div(u)
153+
dedt = -h * ((u[1:nx+1, 0:ny] - u[0:nx, 0:ny]) / dx +
154+
(v[0:nx, 1:ny+1] - v[0:nx, 0:ny]) / dy)
155+
156+
return dudt, dvdt, dedt
157+
158+
def step(u, v, e, u1, v1, e1, u2, v2, e2):
159+
"""
160+
Execute one SSPRK(3,3) time step
161+
"""
162+
dudt, dvdt, dedt = rhs(u, v, e)
163+
u1[1:nx, 0:ny] = u[1:nx, 0:ny] + dt * dudt
164+
v1[0:nx, 1:ny] = v[0:nx, 1:ny] + dt * dvdt
165+
e1[0:nx, 0:ny] = e[0:nx, 0:ny] + dt * dedt
166+
167+
dudt, dvdt, dedt = rhs(u1, v1, e1)
168+
u2[1:nx, 0:ny] = 0.75*u[1:nx, 0:ny] + 0.25*(u1[1:nx, 0:ny] + dt*dudt)
169+
v2[0:nx, 1:ny] = 0.75*v[0:nx, 1:ny] + 0.25*(v1[0:nx, 1:ny] + dt*dvdt)
170+
e2[0:nx, 0:ny] = 0.75*e[0:nx, 0:ny] + 0.25*(e1[0:nx, 0:ny] + dt*dedt)
171+
172+
dudt, dvdt, dedt = rhs(u2, v2, e2)
173+
u[1:nx, 0:ny] = u[1:nx, 0:ny]/3.0 + 2.0/3.0*(u2[1:nx, 0:ny] + dt*dudt)
174+
v[0:nx, 1:ny] = v[0:nx, 1:ny]/3.0 + 2.0/3.0*(v2[0:nx, 1:ny] + dt*dvdt)
175+
e[0:nx, 0:ny] = e[0:nx, 0:ny]/3.0 + 2.0/3.0*(e2[0:nx, 0:ny] + dt*dedt)
176+
177+
t = 0
178+
i_export = 0
179+
next_t_export = 0
180+
initial_v = None
181+
tic = time_mod.perf_counter()
182+
for i in range(nt+1):
183+
sync()
184+
t = i*dt
185+
186+
if t >= next_t_export - 1e-8:
187+
elev_max = float(np.max(e, all_axes))
188+
u_max = float(np.max(u, all_axes))
189+
190+
total_v = float(np.sum(e + h, all_axes)) * dx * dy
191+
if initial_v is None:
192+
initial_v = total_v
193+
diff_v = total_v - initial_v
194+
195+
print(f'{i_export:2d} {i:4d} {t:.3f} elev={elev_max:7.5f} '
196+
f'u={u_max:7.5f} dV={diff_v: 6.3e}')
197+
if elev_max > 1e3 or not math.isfinite(elev_max):
198+
print(f'Invalid elevation value: {elev_max}')
199+
break
200+
i_export += 1
201+
next_t_export = i_export * t_export
202+
sync()
203+
204+
step(u, v, e, u1, v1, e1, u2, v2, e2)
205+
206+
sync()
207+
208+
duration = time_mod.perf_counter() - tic
209+
print(f'Duration: {duration:.2f} s')
210+
211+
e_exact = exact_elev(t, x_t_2d, y_t_2d, lx, ly)
212+
err2 = (e_exact - e) * (e_exact - e) * dx * dy / lx / ly
213+
err_L2 = math.sqrt(float(np.sum(err2, all_axes)))
214+
print(f'L2 error: {err_L2:7.5e}')
215+
216+
if nx == 128 and ny == 128 and not benchmark_mode:
217+
assert numpy.allclose(err_L2, 7.22407e-03)
218+
print('SUCCESS')
219+
220+
fini()
221+
222+
223+
if __name__ == "__main__":
224+
parser = argparse.ArgumentParser(
225+
description='Run wave equation benchmark',
226+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
227+
)
228+
parser.add_argument('-n', '--resolution', type=int, default=128,
229+
help='Number of grid cells in x and y direction.')
230+
parser.add_argument('-t', '--benchmark-mode', action='store_true',
231+
help='Run a fixed number of time steps.')
232+
parser.add_argument('-b', '--backend', type=str, default='ddpt',
233+
choices=['ddpt', 'numpy'],
234+
help='Backend to use.')
235+
args = parser.parse_args()
236+
run(args.resolution, args.backend, args.benchmark_mode)

src/jit/mlir.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -403,11 +403,11 @@ static const char *pass_pipeline =
403403
"expand,memref-expand,arith-bufferize,func-bufferize,func.func(empty-"
404404
"tensor-to-alloc-tensor,scf-bufferize,tensor-bufferize,linalg-"
405405
"bufferize,bufferization-bufferize,linalg-detensorize,tensor-"
406-
"bufferize,finalizing-bufferize,convert-linalg-to-parallel-loops),"
407-
"canonicalize,fold-memref-alias-ops,expand-strided-metadata,convert-"
408-
"math-to-funcs,lower-affine,convert-scf-to-cf,finalize-memref-to-"
409-
"llvm,convert-math-to-llvm,convert-math-to-libm,convert-func-to-llvm,"
410-
"reconcile-unrealized-casts";
406+
"bufferize,finalizing-bufferize,buffer-deallocation,convert-linalg-"
407+
"to-parallel-loops),canonicalize,fold-memref-alias-ops,expand-"
408+
"strided-metadata,convert-math-to-funcs,lower-affine,convert-scf-"
409+
"to-cf,finalize-memref-to-llvm,convert-math-to-llvm,convert-math-to-"
410+
"libm,convert-func-to-llvm,reconcile-unrealized-casts";
411411
JIT::JIT()
412412
: _context(::mlir::MLIRContext::Threading::DISABLED), _pm(&_context),
413413
_verbose(0) {

0 commit comments

Comments
 (0)