Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,13 @@
CodeGenTestSettings: Final[dict[str, dict[str, list[str]]]] = {
"internal": {"extras": ["jax"], "markers": ["not requires_dace"]}
}
CodeGenDaceTestSettings = CodeGenTestSettings | {
"dace": {"extras": [], "markers": ["requires_dace"]},
# Use dace-cartesian group to select the appropriate dace version
CodeGenCartesianTestSettings = CodeGenTestSettings | {
"dace": {"extras": [], "groups": ["dace-cartesian"], "markers": ["requires_dace"]},
}
# Use dace-next group to select the appropriate dace version
CodeGenNextTestSettings = CodeGenTestSettings | {
"dace": {"extras": [], "groups": ["dace-next"], "markers": ["requires_dace"]},
}


Expand Down Expand Up @@ -162,7 +167,7 @@ def test_cartesian(
) -> None:
"""Run selected 'gt4py.cartesian' tests."""

codegen_settings = CodeGenDaceTestSettings[codegen]
codegen_settings = CodeGenCartesianTestSettings[codegen]
device_settings = DeviceTestSettings[device]
extras = [
"standard",
Expand Down Expand Up @@ -245,7 +250,7 @@ def test_next(
) -> None:
"""Run selected 'gt4py.next' tests."""

codegen_settings = CodeGenDaceTestSettings[codegen]
codegen_settings = CodeGenNextTestSettings[codegen]
device_settings = DeviceTestSettings[device]
extras = [
"standard",
Expand Down
16 changes: 14 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ requires = ['cython>=3.0.0', 'setuptools>=77.0.3', 'versioningit>=3.1.1', 'wheel
# -- Dependency groups --
[dependency-groups]
build = ['cython>=3.0.0', 'pip>=22.1.1', 'setuptools>=77.0.3', 'wheel>=0.33.6']
dace-cartesian = [
'dace>=1.0.2' # refined in [tool.uv.sources]
]
dace-next = [
'dace==43!2026.04.27' # uses custom index at 'https://github.com/GridTools/pypi'
]
dev = [
{include-group = 'build'},
{include-group = 'docs'},
Expand Down Expand Up @@ -100,7 +106,6 @@ dependencies = [
'click>=8.0.0',
'cmake>=3.22',
'cytoolz>=1.0.1',
'dace>=2.0.0a3',
'deepdiff>=8.1.0',
'devtools>=0.6',
'factory-boy>=3.3.3',
Expand Down Expand Up @@ -457,6 +462,10 @@ conflicts = [
{extra = 'jax-cuda13'},
{extra = 'rocm6'},
{extra = 'rocm7'}
],
[
{group = 'dace-cartesian'},
{group = 'dace-next'}
]
]
default-groups = ["dev"]
Expand All @@ -473,9 +482,12 @@ name = 'gridtools'
url = 'https://gridtools.github.io/pypi/'

# dace is pulled per dependency group (see `dace` below): a git branch for 'dace-cartesian', the gridtools index for 'dace-next'.
# dace = {index = "gridtools"}
[tool.uv.sources]
atlas4py = {index = "test.pypi"}
dace = [
{git = "https://github.com/romanc/dace", branch = "romanc/math-functions", group = "dace-cartesian"},
{index = "gridtools", group = "dace-next"}
]

# -- versioningit --
[tool.versioningit]
Expand Down
40 changes: 33 additions & 7 deletions src/gt4py/cartesian/backend/dace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def _sdfg_add_arrays_and_edges(
inputs: set[str] | dict[str, dtypes.typeclass],
outputs: set[str] | dict[str, dtypes.typeclass],
origins: dict[str, tuple[int, ...]],
domain: tuple[int, ...],
) -> None:
for name, array in inner_sdfg.arrays.items():
if array.transient:
Expand Down Expand Up @@ -129,12 +130,20 @@ def _sdfg_add_arrays_and_edges(
if axis not in axes:
continue
o = origin[index]
e = field_info[name].boundary.lower_indices[cartesian_index]
lower, upper = field_info[name].boundary[cartesian_index]
s = inner_sdfg.arrays[name].shape[index]
ranges.append(
# s - 1 because ranges are inclusive
(o - max(0, e), o - max(0, e) + s - 1, 1)
)
if axis == CartesianSpace.Axis.K.name:
d = domain[cartesian_index]
ranges.append(
# max(0, lower): a negative lower boundary is clamped to 0, so the range never starts after the origin `o`
# d - 1 because ranges are inclusive
(o - max(0, lower), o + upper + d - 1, 1)
)
else:
ranges.append(
# s - 1 because ranges are inclusive
(o - max(0, lower), o - max(0, lower) + s - 1, 1)
)
index += 1

# Add data dimensions to the range
Expand Down Expand Up @@ -264,7 +273,7 @@ def freeze_origin_domain_sdfg(
nsdfg = state.add_nested_sdfg(inner_sdfg, inputs, outputs)

_sdfg_add_arrays_and_edges(
field_info, wrapper_sdfg, state, inner_sdfg, nsdfg, inputs, outputs, origin
field_info, wrapper_sdfg, state, inner_sdfg, nsdfg, inputs, outputs, origin, domain
)

# in special case of empty domain, remove entire SDFG.
Expand Down Expand Up @@ -920,7 +929,7 @@ def generate_extension(self) -> None:

@register
class DaceGPUBackend(BaseDaceBackend):
"""DaCe python backend using gt4py.cartesian.gtc."""
"""GPU DaCe python with an optimal KJI loop layout"""

name = "dace:gpu"
languages: ClassVar[dict] = {"computation": "cuda", "bindings": ["python"]}
Expand All @@ -933,3 +942,20 @@ class DaceGPUBackend(BaseDaceBackend):

def generate_extension(self) -> None:
return self.make_extension(uses_cuda=True)


@register
class DaceGPUBackendIJK(BaseDaceBackend):
    """GPU DaCe python backend with an optimal IJK loop layout.

    Registered under the name ``dace:gpu_IJK``. Declares CUDA as the computation
    language with Python bindings, and takes its storage layout from the layout
    registry entry stored under the same name.
    """

    # Backend identifier; also the key used for the layout registry lookup below.
    name = "dace:gpu_IJK"
    languages: ClassVar[dict] = {"computation": "cuda", "bindings": ["python"]}
    # Must stay after `name`: the lookup reads the class-body variable defined above.
    storage_info: ClassVar[layout.LayoutInfo] = layout_registry.from_name(name)
    MODULE_GENERATOR_CLASS = DaCeCUDAPyExtModuleGenerator
    # Shared GT backend options extended with a boolean `device_sync` option.
    options: ClassVar[GTBackendOptions] = {
        **BaseGTBackend.GT_BACKEND_OPTS,
        "device_sync": {"versioning": True, "type": bool},
    }

    def generate_extension(self) -> None:
        """Build the extension through ``make_extension`` with CUDA enabled."""
        return self.make_extension(uses_cuda=True)
28 changes: 14 additions & 14 deletions src/gt4py/cartesian/gtc/dace/oir_to_tasklet.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,38 +226,38 @@ def visit_NativeFunction(self, node: common.NativeFunction, **_kwargs: Any) -> s
common.NativeFunction.ABS: "abs",
common.NativeFunction.MIN: "min",
common.NativeFunction.MAX: "max",
common.NativeFunction.MOD: "fmod",
common.NativeFunction.MOD: "dace.math.fmod",
common.NativeFunction.SIN: "dace.math.sin",
common.NativeFunction.COS: "dace.math.cos",
common.NativeFunction.TAN: "dace.math.tan",
common.NativeFunction.ARCSIN: "asin",
common.NativeFunction.ARCCOS: "acos",
common.NativeFunction.ARCTAN: "atan",
common.NativeFunction.ARCSIN: "dace.math.asin",
common.NativeFunction.ARCCOS: "dace.math.acos",
common.NativeFunction.ARCTAN: "dace.math.atan",
common.NativeFunction.SINH: "dace.math.sinh",
common.NativeFunction.COSH: "dace.math.cosh",
common.NativeFunction.TANH: "dace.math.tanh",
common.NativeFunction.ARCSINH: "asinh",
common.NativeFunction.ARCCOSH: "acosh",
common.NativeFunction.ARCTANH: "atanh",
common.NativeFunction.ARCSINH: "dace.math.asinh",
common.NativeFunction.ARCCOSH: "dace.math.acosh",
common.NativeFunction.ARCTANH: "dace.math.atanh",
common.NativeFunction.SQRT: "dace.math.sqrt",
common.NativeFunction.POW: "dace.math.pow",
common.NativeFunction.EXP: "dace.math.exp",
common.NativeFunction.LOG: "dace.math.log",
common.NativeFunction.LOG10: "log10",
common.NativeFunction.GAMMA: "tgamma",
common.NativeFunction.CBRT: "cbrt",
common.NativeFunction.LOG10: "dace.math.log10",
common.NativeFunction.GAMMA: "dace.math.tgamma",
common.NativeFunction.CBRT: "dace.math.cbrt",
common.NativeFunction.ISFINITE: "isfinite",
common.NativeFunction.ISINF: "isinf",
common.NativeFunction.ISNAN: "isnan",
common.NativeFunction.FLOOR: "dace.math.ifloor",
common.NativeFunction.CEIL: "ceil",
common.NativeFunction.TRUNC: "trunc",
common.NativeFunction.CEIL: "dace.math.ceil",
common.NativeFunction.TRUNC: "dace.math.trunc",
common.NativeFunction.INT32: "dace.int32",
common.NativeFunction.INT64: "dace.int64",
common.NativeFunction.FLOAT32: "dace.float32",
common.NativeFunction.FLOAT64: "dace.float64",
common.NativeFunction.ERF: "erf",
common.NativeFunction.ERFC: "erfc",
common.NativeFunction.ERF: "dace.math.erf",
common.NativeFunction.ERFC: "dace.math.erfc",
common.NativeFunction.ROUND: "nearbyint",
common.NativeFunction.ROUND_AWAY_FROM_ZERO: "round",
}
Expand Down
4 changes: 1 addition & 3 deletions src/gt4py/cartesian/gtc/dace/oir_to_treeir.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,6 @@ def visit_Stencil(self, node: oir.Stencil) -> tir.TreeRoot:
param,
field_without_mask_extents[param.name],
k_bound,
symbols,
),
strides=get_dace_strides(param, symbols),
storage=DEFAULT_STORAGE_TYPE[self._device_type],
Expand All @@ -374,7 +373,7 @@ def visit_Stencil(self, node: oir.Stencil) -> tir.TreeRoot:
# than persistent will yield issues with memory leaks.
containers[field.name] = data.Array(
dtype=utils.data_type_to_dace_typeclass(field.dtype),
shape=get_dace_shape(field, field_extent, k_bound, symbols),
shape=get_dace_shape(field, field_extent, k_bound),
strides=get_dace_strides(field, symbols),
transient=True,
lifetime=dtypes.AllocationLifetime.Persistent,
Expand Down Expand Up @@ -532,7 +531,6 @@ def get_dace_shape(
field: oir.FieldDecl,
extent: definitions.Extent,
k_bound: tuple[int, int],
symbols: tir.SymbolDict,
) -> list[symbolic.symbol]:
shape = []
for index, axis in enumerate(tir.Axis.dims_3d()):
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/cartesian/utils/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class GPUCompilerName(enum.Enum):
class GPUConfiguration:
name: GPUCompilerName
"""Name identifier of the compiler"""
gpu_compile_flags: list[str]
gpu_compile_flags: str
"""Compile flags for device code"""
binary_path: str
"""Path to binaries for GPU compiler & tools"""
Expand Down Expand Up @@ -181,7 +181,7 @@ def gpu_configuration(optimization_level: str) -> GPUConfiguration:

return GPUConfiguration(
name=name,
gpu_compile_flags=gpu_compile_flags,
gpu_compile_flags=" ".join(gpu_compile_flags).strip(),
binary_path=os.path.join(cuda_root, "bin"),
include_path=os.path.join(cuda_root, "include"),
library_path=library_path,
Expand Down
9 changes: 9 additions & 0 deletions src/gt4py/storage/cartesian/layout_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ def register(name: str, info: LayoutInfo) -> None:
is_optimal_layout=layout_checker_factory(layout_maker_factory((2, 1, 0))),
),
)
# Layout entry for the "dace:gpu_IJK" name: GPU device with alignment 32.
# The same (0, 1, 2) tuple is passed to `layout_maker_factory` for both the
# layout map and the optimal-layout checker, whereas the entry registered just
# above checks against (2, 1, 0).
# NOTE(review): the tuple presumably encodes the axis ordering — confirm
# against `layout_maker_factory`.
register(
    "dace:gpu_IJK",
    LayoutInfo(
        alignment=32,
        device="gpu",
        layout_map=layout_maker_factory((0, 1, 2)),
        is_optimal_layout=layout_checker_factory(layout_maker_factory((0, 1, 2))),
    ),
)
register(
"debug",
LayoutInfo(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -800,10 +800,7 @@ class TestVariableKAndReadOutside(gt_testing.StencilTestSuite):

def definition(field_in, field_out, index):
with computation(PARALLEL), interval(1, None):
field_out[0, 0, 0] = (
field_in[0, 0, index] # noqa: F841 [unused-variable]
+ field_in[0, 0, -2]
)
field_out[0, 0, 0] = field_in[0, 0, index] + field_in[0, 0, -2]

def validation(field_in, field_out, index, *, domain, origin):
idx = 1 + (np.arange(domain[-1]) + index)[1:]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,22 @@ def test_integer_power_of_integer() -> None:
tasklet_code = visitor.visit_NativeFuncCall(pow_call, ctx=fake_context, is_target=False)

assert "ipow" not in tasklet_code


@pytest.mark.parametrize(
    "arg",
    [
        oir.Literal(value="2", dtype=float_type)
        for float_type in (common.DataType.FLOAT32, common.DataType.FLOAT64)
    ],
)
def test_log10_respects_floating_point_precision(arg: oir.Literal) -> None:
    """Check that a LOG10 call is emitted as ``dace.math.log10``.

    Runs once with a float32 and once with a float64 literal argument and
    inspects the tasklet code produced by ``OIRToTasklet.visit_NativeFuncCall``.
    """
    call_node = oir.NativeFuncCall(func=common.NativeFunction.LOG10, args=[arg])

    # Minimal stand-in context: only the fields required by the constructor.
    stub_ctx = oir_to_tasklet.Context(
        code="asdf", targets=set(), inputs={}, outputs={}, tree=None, scope=None
    )
    generated = oir_to_tasklet.OIRToTasklet().visit_NativeFuncCall(
        call_node, ctx=stub_ctx, is_target=False
    )

    assert "dace.math.log10" in generated
Loading
Loading