Skip to content

Commit 81f8067

Browse files
committed
Fix numba-cuda CPU import crash: install only on GPU image
numba-cuda v0.30.0 depends on cuda-bindings which requires libcudart.so at import time. This crashes on the CPU image where no CUDA runtime is installed. Move numba-cuda install into the GPU-only section while keeping the numba upgrade for both CPU and GPU. b/485275559
1 parent d6a167a commit 81f8067

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

Dockerfile.tmpl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ RUN uv pip install --no-build-isolation --no-cache --system "git+https://github.
2828
# b/468367647: Unpin protobuf, version greater than v5.29.5 causes issues with numerous packages
2929
RUN uv pip install --system --force-reinstall --no-cache --no-deps torchtune
3030
RUN uv pip install --system --force-reinstall --no-cache "protobuf==5.29.5"
31-
# b/493600019: Colab base image ships numba/numba-cuda that do not support NumPy 2.4; upgrade both.
32-
RUN uv pip install --system --force-reinstall --no-cache numba numba-cuda
31+
# b/493600019: Colab base image ships numba that does not support NumPy 2.4; upgrade to latest.
32+
RUN uv pip install --system --force-reinstall --no-cache numba
3333

3434
# Adding non-package dependencies:
3535
ADD clean-layer.sh /tmp/clean-layer.sh
@@ -40,6 +40,8 @@ ARG PACKAGE_PATH=/usr/local/lib/python3.12/dist-packages
4040

4141
# Install GPU-specific non-pip packages.
4242
{{ if eq .Accelerator "gpu" }}
43+
# b/493600019: numba-cuda v0.30.0 fixes np.trapz removal in NumPy 2.4 but requires libcudart.so (GPU only).
44+
RUN uv pip install --system --force-reinstall --no-cache numba-cuda
4345
RUN uv pip install --system --no-cache "pycuda"
4446
{{ end }}
4547

0 commit comments

Comments
 (0)