Skip to content
11 changes: 11 additions & 0 deletions docs/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -403,3 +403,14 @@ cog predict --use-replicate-token -i prompt="Hello"
# Multiple environment variables
cog run -e CUDA_VISIBLE_DEVICES=0 -e BATCH_SIZE=32 python train.py
```

# Selecting Ubuntu version for CUDA base image

To select a specific Ubuntu version for the CUDA base image, set the environment variable `COG_UBUNTU_VERSION` before building:

```bash
export COG_UBUNTU_VERSION=22.04
cog build --use-cog-base-image=false
```

If not set, the latest supported Ubuntu version will be used.
12 changes: 10 additions & 2 deletions pkg/config/compatibility.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
_ "embed"
"encoding/json"
"fmt"
"os"
"sort"
"strings"

Expand Down Expand Up @@ -255,13 +256,20 @@ func versionGreater(a string, b string) (bool, error) {

func CUDABaseImageFor(cuda string, cuDNN string) (string, error) {
var images []CUDABaseImage
ubuntuEnv := os.Getenv("COG_UBUNTU_VERSION")
for _, image := range CUDABaseImages {
if version.Matches(cuda, image.CUDA) && image.CuDNN == cuDNN {
images = append(images, image)
if ubuntuEnv == "" || image.Ubuntu == ubuntuEnv {
images = append(images, image)
}
}
}
if len(images) == 0 {
return "", fmt.Errorf("No matching base image for CUDA %s and CuDNN %s", cuda, cuDNN)
ubuntuMsg := ubuntuEnv
if ubuntuEnv == "" {
ubuntuMsg = "any"
}
return "", fmt.Errorf("No matching base image for CUDA %s, CuDNN %s, Ubuntu %s", cuda, cuDNN, ubuntuMsg)
}

sort.Slice(images, func(i, j int) bool {
Expand Down
40 changes: 40 additions & 0 deletions pkg/config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,46 @@ func TestCUDABaseImageTag(t *testing.T) {
require.Equal(t, "nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04", imageTag)
}

func TestCUDABaseImageTagWithUbuntuEnv(t *testing.T) {
// By default, CUDA 12.8 + Python 3.12 should select Ubuntu 24.04
os.Unsetenv("COG_UBUNTU_VERSION")
configDefault := &Config{
Build: &Build{
GPU: true,
PythonVersion: "3.12",
CUDA: "12.8.0",
CuDNN: "9",
},
}

err := configDefault.ValidateAndComplete("")
require.NoError(t, err)

imageTag, err := configDefault.CUDABaseImageTag()
require.NoError(t, err)
require.Equal(t, "nvidia/cuda:12.8.0-cudnn-devel-ubuntu24.04", imageTag)

// If COG_UBUNTU_VERSION is set to 22.04, should select Ubuntu 22.04 image
os.Setenv("COG_UBUNTU_VERSION", "22.04")
configEnv := &Config{
Build: &Build{
GPU: true,
PythonVersion: "3.12",
CUDA: "12.8.0",
CuDNN: "9",
},
}

err = configEnv.ValidateAndComplete("")
require.NoError(t, err)

imageTag, err = configEnv.CUDABaseImageTag()
require.NoError(t, err)
require.Equal(t, "nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04", imageTag)

os.Unsetenv("COG_UBUNTU_VERSION")
}

func TestBuildRunItemStringYAML(t *testing.T) {
type BuildWrapper struct {
Build *Build `yaml:"build"`
Expand Down