ENH: delegate broadcast_shapes#713
Conversation
There was a problem hiding this comment.
Pull request overview
This PR advances the delegation work in #100 by making broadcast_shapes a delegated public API function. It keeps the existing array-agnostic fallback behavior (including support for None and math.nan) while using NumPy’s broadcast_shapes for the fast-path when shapes are fully known integers and NumPy is available.
Changes:
- Moved the public
broadcast_shapeswrapper intoarray_api_extra._delegationand re-exported it fromarray_api_extra.__init__. - Added a NumPy delegation fast-path for fully-integer shapes, preserving the existing fallback for unknown sizes and no-NumPy runtimes.
- Added focused tests to ensure delegation is used only for known-integer shapes and that fallbacks are preserved.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
tests/test_funcs.py |
Adds tests for NumPy delegation and fallback behaviors of broadcast_shapes. |
src/array_api_extra/_lib/_funcs.py |
Replaces the large broadcast_shapes docstring with a delegation reference (implementation unchanged). |
src/array_api_extra/_delegation.py |
Introduces the public broadcast_shapes wrapper and NumPy fast-path delegation. |
src/array_api_extra/__init__.py |
Re-exports broadcast_shapes from the delegation layer instead of _lib._funcs. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| @@ -221,45 +221,7 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType) -> Array: | |||
| # `float` in signature to accept `math.nan` for Dask. | |||
| # `int`s are still accepted as `float` is a superclass of `int` in typing | |||
| def broadcast_shapes(*shapes: tuple[float | None, ...]) -> tuple[int | None, ...]: | |||
7dd938d to
5939f4b
Compare
| if np is not None and all( | ||
| isinstance(size, int) for shape in shapes for size in shape | ||
| ): | ||
| int_shapes = cast(tuple[tuple[int, ...], ...], shapes) | ||
| return cast(tuple[int | None, ...], np.broadcast_shapes(*int_shapes)) |
There was a problem hiding this comment.
please see how the other functions in this file work — we want to grab xp
array-api-extra/src/array_api_extra/_delegation.py
Lines 220 to 221 in bc126fa
and then use functions from that namespace where possible:
array-api-extra/src/array_api_extra/_delegation.py
Lines 227 to 236 in bc126fa
|
Updated based on Copilot review:\n- Added the numpydoc ignore marker to the internal fallback implementation.\n- Moved the optional NumPy import into a helper so NumPy is not imported at module import time. |
5939f4b to
941bbdb
Compare
|
Thanks for the review. Updated the implementation to follow the existing delegation pattern: |
| and ( | ||
| is_numpy_namespace(xp) | ||
| or is_cupy_namespace(xp) | ||
| or is_dask_namespace(xp) |
There was a problem hiding this comment.
(dask doesn't have this function)
| or is_dask_namespace(xp) |
There was a problem hiding this comment.
Thanks, fixed. Removed Dask from the native broadcast_shapes delegation path since Dask does not provide this function, so Dask now falls back to the generic implementation.
941bbdb to
a0c98d0
Compare
Part of #100.\n\nSummary:\n- Move the public broadcast_shapes wrapper into the delegation layer.\n- Delegate known integer shapes to numpy.broadcast_shapes when NumPy is available.\n- Preserve the existing fallback for None, math.nan, and no-NumPy runtimes.\n\nTesting:\n- pixi run -e tests pytest -q tests/test_funcs.py::TestBroadcastShapes\n- pixi run -e tests tests\n- pixi run -e lint ruff check src/array_api_extra/_delegation.py src/array_api_extra/_lib/_funcs.py src/array_api_extra/init.py tests/test_funcs.py\n- pixi run -e lint ruff format --check src/array_api_extra/_delegation.py src/array_api_extra/_lib/_funcs.py src/array_api_extra/init.py tests/test_funcs.py\n- pixi run -e lint mypy src/array_api_extra/_delegation.py src/array_api_extra/_lib/_funcs.py tests/test_funcs.py\n- pixi run -e lint pyright src/array_api_extra/_delegation.py src/array_api_extra/_lib/_funcs.py tests/test_funcs.py\n- pixi run -e lint pyrefly check src/array_api_extra/_delegation.py src/array_api_extra/_lib/_funcs.py tests/test_funcs.py\n- pixi run -e docs docs