Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,29 @@ MODEL_RUNNER_HOST=http://localhost:13434 ./model-cli list
- [Model Specification](https://github.com/docker/model-spec/blob/main/spec.md)
- [Community Slack Channel](https://dockercommunity.slack.com/archives/C09H9P5E57B)

### ModelPack Compatibility

Docker Model Runner supports both Docker model-spec artifacts and CNCF ModelPack artifacts stored in OCI registries.

For ModelPack images, Docker Model Runner accepts:

- config media type: `application/vnd.cncf.model.config.v1+json`
- weight layer media types, including:
- `application/vnd.cncf.model.weight.v1.gguf`
- `application/vnd.cncf.model.weight.v1.safetensors`

This means you can pull and run a ModelPack artifact with the same user workflow:

```bash
# Pull from any OCI-compliant registry
docker model pull <registry>/<namespace>/<model>:<tag>

# Run the model
docker model run <registry>/<namespace>/<model>:<tag> "Hello"
```

If you are publishing artifacts for compatibility across tooling, ensure your image config and layer media types follow the ModelPack spec so downstream clients can detect and use the correct format.

## Using the Makefile

This project includes a Makefile to simplify common development tasks. Docker targets require Docker Desktop >= 4.41.0.
Expand Down
5 changes: 3 additions & 2 deletions pkg/distribution/distribution/bundle_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package distribution

import (
"bytes"
"errors"
"os"
"path/filepath"
Expand Down Expand Up @@ -142,8 +143,8 @@ func TestBundle(t *testing.T) {
if err != nil {
t.Fatalf("Failed to read file with expected contents: %v", err)
}
if string(got) != string(expected) {
t.Fatalf("File contents did not match expected contents. Expected: %s, got: %s", expected, got)
if !bytes.Equal(got, expected) {
t.Fatalf("File contents did not match expected contents")
}
}
})
Expand Down
5 changes: 4 additions & 1 deletion pkg/distribution/distribution/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/docker/model-runner/pkg/distribution/internal/mutate"
"github.com/docker/model-runner/pkg/distribution/internal/progress"
"github.com/docker/model-runner/pkg/distribution/internal/store"
"github.com/docker/model-runner/pkg/distribution/modelpack"
"github.com/docker/model-runner/pkg/distribution/oci"
"github.com/docker/model-runner/pkg/distribution/oci/authn"
"github.com/docker/model-runner/pkg/distribution/oci/remote"
Expand Down Expand Up @@ -786,7 +787,9 @@ func checkCompat(image types.ModelArtifact, log *slog.Logger, reference string,
if err != nil {
return err
}
if manifest.Config.MediaType != types.MediaTypeModelConfigV01 && manifest.Config.MediaType != types.MediaTypeModelConfigV02 {
if manifest.Config.MediaType != types.MediaTypeModelConfigV01 &&
manifest.Config.MediaType != types.MediaTypeModelConfigV02 &&
manifest.Config.MediaType != oci.MediaType(modelpack.MediaTypeModelConfigV1) {
return fmt.Errorf("config type %q is unsupported: %w", manifest.Config.MediaType, ErrUnsupportedMediaType)
}

Expand Down
225 changes: 216 additions & 9 deletions pkg/distribution/distribution/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,163 @@ import (
"path/filepath"
"strings"
"testing"
"time"

"github.com/docker/model-runner/pkg/distribution/internal/mutate"
"github.com/docker/model-runner/pkg/distribution/internal/partial"
"github.com/docker/model-runner/pkg/distribution/internal/progress"
"github.com/docker/model-runner/pkg/distribution/internal/testutil"
"github.com/docker/model-runner/pkg/distribution/modelpack"
"github.com/docker/model-runner/pkg/distribution/oci"
"github.com/docker/model-runner/pkg/distribution/oci/reference"
"github.com/docker/model-runner/pkg/distribution/oci/remote"
mdregistry "github.com/docker/model-runner/pkg/distribution/registry"
"github.com/docker/model-runner/pkg/distribution/registry/testregistry"
"github.com/docker/model-runner/pkg/inference/platform"
"github.com/opencontainers/go-digest"
)

var (
testGGUFFile = filepath.Join("..", "assets", "dummy.gguf")
)

type modelPackTestArtifact struct {
rawConfig []byte
layers []oci.Layer
}

func (m *modelPackTestArtifact) Layers() ([]oci.Layer, error) {
return m.layers, nil
}

func (m *modelPackTestArtifact) MediaType() (oci.MediaType, error) {
manifest, err := m.Manifest()
if err != nil {
return "", err
}
return manifest.MediaType, nil
}

func (m *modelPackTestArtifact) Size() (int64, error) {
rawManifest, err := m.RawManifest()
if err != nil {
return 0, err
}
size := int64(len(rawManifest) + len(m.rawConfig))
for _, layer := range m.layers {
layerSize, err := layer.Size()
if err != nil {
return 0, err
}
size += layerSize
}
return size, nil
}

func (m *modelPackTestArtifact) ConfigName() (oci.Hash, error) {
hash, _, err := oci.SHA256(bytes.NewReader(m.rawConfig))
return hash, err
}

func (m *modelPackTestArtifact) ConfigFile() (*oci.ConfigFile, error) {
return nil, errors.New("invalid for model")
}

func (m *modelPackTestArtifact) RawConfigFile() ([]byte, error) {
return m.rawConfig, nil
}

func (m *modelPackTestArtifact) Digest() (oci.Hash, error) {
rawManifest, err := m.RawManifest()
if err != nil {
return oci.Hash{}, err
}
hash, _, err := oci.SHA256(bytes.NewReader(rawManifest))
return hash, err
}

func (m *modelPackTestArtifact) Manifest() (*oci.Manifest, error) {
return partial.ManifestForLayers(m)
}

func (m *modelPackTestArtifact) RawManifest() ([]byte, error) {
manifest, err := m.Manifest()
if err != nil {
return nil, err
}
return json.Marshal(manifest)
}

func (m *modelPackTestArtifact) LayerByDigest(hash oci.Hash) (oci.Layer, error) {
for _, layer := range m.layers {
layerDigest, err := layer.Digest()
if err != nil {
return nil, err
}
if layerDigest == hash {
return layer, nil
}
}
return nil, fmt.Errorf("layer with digest %s not found", hash)
}

func (m *modelPackTestArtifact) LayerByDiffID(hash oci.Hash) (oci.Layer, error) {
for _, layer := range m.layers {
layerDiffID, err := layer.DiffID()
if err != nil {
return nil, err
}
if layerDiffID == hash {
return layer, nil
}
}
return nil, fmt.Errorf("layer with diffID %s not found", hash)
}

func (m *modelPackTestArtifact) GetConfigMediaType() oci.MediaType {
return oci.MediaType(modelpack.MediaTypeModelConfigV1)
}

func newModelPackTestArtifact(t *testing.T, modelFile string) *modelPackTestArtifact {
t.Helper()

layer, err := partial.NewLayer(modelFile, oci.MediaType(modelpack.MediaTypeWeightGGUF))
if err != nil {
t.Fatalf("Failed to create ModelPack layer: %v", err)
}

diffID, err := layer.DiffID()
if err != nil {
t.Fatalf("Failed to get layer DiffID: %v", err)
}

now := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
mp := modelpack.Model{
Descriptor: modelpack.ModelDescriptor{
CreatedAt: &now,
Name: "dummy-modelpack",
},
Config: modelpack.ModelConfig{
Format: "gguf",
ParamSize: "8B",
},
ModelFS: modelpack.ModelFS{
Type: "layers",
DiffIDs: []digest.Digest{digest.Digest(diffID.String())},
},
}

rawConfig, err := json.Marshal(mp)
if err != nil {
t.Fatalf("Failed to marshal ModelPack config: %v", err)
}

return &modelPackTestArtifact{
rawConfig: rawConfig,
layers: []oci.Layer{layer},
}
}

// newTestClient creates a new client configured for testing with plain HTTP enabled.
func newTestClient(storeRootPath string) (*Client, error) {
return NewClient(
Expand Down Expand Up @@ -98,8 +239,8 @@ func TestClientPullModel(t *testing.T) {
t.Fatalf("Failed to read pulled model: %v", err)
}

if string(pulledContent) != string(modelContent) {
t.Errorf("Pulled model content doesn't match original: got %q, want %q", pulledContent, modelContent)
if !bytes.Equal(pulledContent, modelContent) {
t.Errorf("Pulled model content doesn't match original")
}
})

Expand Down Expand Up @@ -137,8 +278,74 @@ func TestClientPullModel(t *testing.T) {
t.Fatalf("Failed to read pulled model: %v", err)
}

if string(pulledContent) != string(modelContent) {
t.Errorf("Pulled model content doesn't match original: got %q, want %q", pulledContent, modelContent)
if !bytes.Equal(pulledContent, modelContent) {
t.Errorf("Pulled model content doesn't match original")
}
})

t.Run("pull modelpack artifact", func(t *testing.T) {
tempDir := t.TempDir()

testClient, err := newTestClient(tempDir)
if err != nil {
t.Fatalf("Failed to create client: %v", err)
}

mpTag := registryHost + "/modelpack-test/model:v1.0.0"
ref, err := reference.ParseReference(mpTag)
if err != nil {
t.Fatalf("Failed to parse reference: %v", err)
}

mpModel := newModelPackTestArtifact(t, testGGUFFile)
if err := remote.Write(ref, mpModel, nil, remote.WithPlainHTTP(true)); err != nil {
t.Fatalf("Failed to push ModelPack model: %v", err)
}

if err := testClient.PullModel(t.Context(), mpTag, nil); err != nil {
t.Fatalf("Failed to pull ModelPack model: %v", err)
}

pulledModel, err := testClient.GetModel(mpTag)
if err != nil {
t.Fatalf("Failed to get pulled model: %v", err)
}

ggufPaths, err := pulledModel.GGUFPaths()
if err != nil {
t.Fatalf("Failed to get GGUF paths: %v", err)
}
if len(ggufPaths) != 1 {
t.Fatalf("Unexpected number of GGUF files: %d", len(ggufPaths))
}

pulledContent, err := os.ReadFile(ggufPaths[0])
if err != nil {
t.Fatalf("Failed to read pulled GGUF file: %v", err)
}

originalContent, err := os.ReadFile(testGGUFFile)
if err != nil {
t.Fatalf("Failed to read source GGUF file: %v", err)
}

if !bytes.Equal(pulledContent, originalContent) {
t.Errorf("Pulled ModelPack model content doesn't match original")
}

cfg, err := pulledModel.Config()
if err != nil {
t.Fatalf("Failed to read pulled model config: %v", err)
}
if cfg.GetFormat() != "gguf" {
t.Errorf("Config format = %q, want %q", cfg.GetFormat(), "gguf")
}
if cfg.GetParameters() != "8B" {
t.Errorf("Config parameters = %q, want %q", cfg.GetParameters(), "8B")
}

if _, ok := cfg.(*modelpack.Model); !ok {
t.Errorf("Config type = %T, want *modelpack.Model", cfg)
}
})

Expand Down Expand Up @@ -332,8 +539,8 @@ func TestClientPullModel(t *testing.T) {
t.Fatalf("Failed to read pulled model: %v", err)
}

if string(pulledContent) != string(testModelContent) {
t.Errorf("Pulled model content doesn't match original: got %q, want %q", pulledContent, testModelContent)
if !bytes.Equal(pulledContent, testModelContent) {
t.Errorf("Pulled model content doesn't match original")
}

// Create a modified version of the model
Expand Down Expand Up @@ -382,8 +589,8 @@ func TestClientPullModel(t *testing.T) {
t.Fatalf("Failed to read updated pulled model: %v", err)
}

if string(updatedPulledContent) != string(updatedContent) {
t.Errorf("Updated pulled model content doesn't match: got %q, want %q", updatedPulledContent, updatedContent)
if !bytes.Equal(updatedPulledContent, updatedContent) {
t.Errorf("Updated pulled model content doesn't match")
}
})

Expand Down Expand Up @@ -526,7 +733,7 @@ func TestClientPullModel(t *testing.T) {
t.Fatalf("Failed to read pulled model: %v", err)
}

if string(pulledContent) != string(modelContent) {
if !bytes.Equal(pulledContent, modelContent) {
t.Errorf("Pulled model content doesn't match original")
}
})
Expand Down
5 changes: 3 additions & 2 deletions pkg/distribution/distribution/ecr_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package distribution

import (
"bytes"
"os"
"testing"

Expand Down Expand Up @@ -79,8 +80,8 @@ func TestECRIntegration(t *testing.T) {
t.Fatalf("Failed to read pulled model: %v", err)
}

if string(pulledContent) != string(modelContent) {
t.Errorf("Pulled model content doesn't match original: got %q, want %q", pulledContent, modelContent)
if !bytes.Equal(pulledContent, modelContent) {
t.Errorf("Pulled model content doesn't match original")
}
})

Expand Down
5 changes: 3 additions & 2 deletions pkg/distribution/distribution/gar_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package distribution

import (
"bytes"
"os"
"testing"

Expand Down Expand Up @@ -80,8 +81,8 @@ func TestGARIntegration(t *testing.T) {
t.Fatalf("Failed to read pulled model: %v", err)
}

if string(pulledContent) != string(modelContent) {
t.Errorf("Pulled model content doesn't match original: got %q, want %q", pulledContent, modelContent)
if !bytes.Equal(pulledContent, modelContent) {
t.Errorf("Pulled model content doesn't match original")
}
})

Expand Down