Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions pkg/backend/fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ import (
func (b *backend) Fetch(ctx context.Context, target string, cfg *config.Fetch) error {
logrus.Infof("fetch: fetching from %s", target)

// Apply default hooks when caller leaves it unset to avoid nil deref.
if cfg.Hooks == nil {
defaults := config.NewFetch()
cfg.Hooks = defaults.Hooks
}
Comment thread
chlins marked this conversation as resolved.

// fetchByDragonfly is called if a Dragonfly endpoint is specified in the configuration.
if cfg.DragonflyEndpoint != "" {
logrus.Infof("fetch: using dragonfly for %s", target)
Expand Down Expand Up @@ -117,11 +123,19 @@ func (b *backend) Fetch(ctx context.Context, target string, cfg *config.Fetch) e
}

logrus.Debugf("fetch: processing layer %s", layer.Digest)
if cfg.Hooks.BeforePullLayer(layer, manifest) {
logrus.Debugf("fetch: layer %s skipped by hook", layer.Digest)
pb.Complete(layer.Digest.String(), fmt.Sprintf("%s %s", internalpb.NormalizePrompt("Skipped blob"), layer.Digest.String()))
cfg.Hooks.AfterPullLayer(layer, true, nil)
return nil
}
if err := tracker.TrackTransfer(func() error {
return pullAndExtractFromRemote(ctx, pb, internalpb.NormalizePrompt("Fetching blob"), client, cfg.Output, layer, tracker)
}); err != nil {
cfg.Hooks.AfterPullLayer(layer, false, err)
return err
}
cfg.Hooks.AfterPullLayer(layer, false, nil)

logrus.Debugf("fetch: successfully processed layer %s", layer.Digest)
return nil
Expand Down
9 changes: 7 additions & 2 deletions pkg/backend/fetch_by_d7y.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,14 @@ func (b *backend) fetchByDragonfly(ctx context.Context, target string, cfg *conf
func fetchLayerByDragonfly(ctx context.Context, pb *internalpb.ProgressBar, client dfdaemon.DfdaemonDownloadClient, ref Referencer, manifest ocispec.Manifest, desc ocispec.Descriptor, authToken string, cfg *config.Fetch) error {
err := retry.Do(func() error {
logrus.Debugf("fetch: processing layer %s", desc.Digest)
cfg.Hooks.BeforePullLayer(desc, manifest) // Call before hook
if cfg.Hooks.BeforePullLayer(desc, manifest) {
logrus.Debugf("fetch: layer %s skipped by hook", desc.Digest)
pb.Complete(desc.Digest.String(), fmt.Sprintf("%s %s", internalpb.NormalizePrompt("Skipped blob"), desc.Digest.String()))
cfg.Hooks.AfterPullLayer(desc, true, nil)
return nil
}
err := downloadAndExtractFetchLayer(ctx, pb, client, ref, desc, authToken, cfg)
cfg.Hooks.AfterPullLayer(desc, err) // Call after hook
cfg.Hooks.AfterPullLayer(desc, false, err) // Call after hook
if err != nil {
err = fmt.Errorf("pull: failed to download and extract layer %s: %w", desc.Digest, err)
logrus.Error(err)
Expand Down
182 changes: 182 additions & 0 deletions pkg/backend/fetch_hooks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
/*
* Copyright 2025 The ModelPack Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package backend

import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"os"
"strings"
"sync"
"sync/atomic"
"testing"

modelspec "github.com/modelpack/model-spec/specs-go/v1"
godigest "github.com/opencontainers/go-digest"
ocispec "github.com/opencontainers/image-spec/specs-go/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/modelpack/modctl/pkg/config"
)

// recordingFetchHook tracks hook invocations and can request specific layers
// to be skipped by digest.
type recordingFetchHook struct {
mu sync.Mutex
skipDigests map[string]bool
beforeCount int32
afterCalls []afterFetchCall
}

type afterFetchCall struct {
digest string
skipped bool
err error
}

func (r *recordingFetchHook) BeforePullLayer(desc ocispec.Descriptor, _ ocispec.Manifest) bool {
atomic.AddInt32(&r.beforeCount, 1)
r.mu.Lock()
defer r.mu.Unlock()
return r.skipDigests[desc.Digest.String()]
}

func (r *recordingFetchHook) AfterPullLayer(desc ocispec.Descriptor, skipped bool, err error) {
r.mu.Lock()
defer r.mu.Unlock()
r.afterCalls = append(r.afterCalls, afterFetchCall{
digest: desc.Digest.String(),
skipped: skipped,
err: err,
})
}

// startFetchTestServer spins up an HTTP server that serves a manifest with
// two layers and tracks how many times each blob is requested.
func startFetchTestServer(t *testing.T) (server *httptest.Server, file1Digest, file2Digest godigest.Digest, blobHits map[string]*int32) {
t.Helper()

const (
file1Content = "file1 content..."
file2Content = "file2 content..."
)
file1Digest = godigest.FromString(file1Content)
file2Digest = godigest.FromString(file2Content)

hits := map[string]*int32{
file1Digest.String(): new(int32),
file2Digest.String(): new(int32),
}

server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/v2/":
w.WriteHeader(http.StatusOK)
case "/v2/test/model/manifests/latest":
manifest := ocispec.Manifest{
Layers: []ocispec.Descriptor{
{
MediaType: "application/octet-stream.raw",
Digest: file1Digest,
Size: int64(len(file1Content)),
Annotations: map[string]string{
modelspec.AnnotationFilepath: "file1.txt",
},
},
{
MediaType: "application/octet-stream.raw",
Digest: file2Digest,
Size: int64(len(file2Content)),
Annotations: map[string]string{
modelspec.AnnotationFilepath: "file2.txt",
},
},
},
}
w.Header().Set("Content-Type", "application/json")
require.NoError(t, json.NewEncoder(w).Encode(manifest))
case fmt.Sprintf("/v2/test/model/blobs/%s", file1Digest):
atomic.AddInt32(hits[file1Digest.String()], 1)
_, err := w.Write([]byte(file1Content))
require.NoError(t, err)
case fmt.Sprintf("/v2/test/model/blobs/%s", file2Digest):
atomic.AddInt32(hits[file2Digest.String()], 1)
_, err := w.Write([]byte(file2Content))
require.NoError(t, err)
default:
t.Logf("Unexpected request to %s", r.URL.Path)
w.WriteHeader(http.StatusNotFound)
}
}))

return server, file1Digest, file2Digest, hits
}

// TestFetch_HookSkipShortCircuitsLayer verifies that returning skip=true from
// BeforePullLayer prevents the blob from being downloaded and that
// AfterPullLayer is still invoked with skipped=true.
func TestFetch_HookSkipShortCircuitsLayer(t *testing.T) {
tempDir, err := os.MkdirTemp("", "fetch-hook-test")
require.NoError(t, err)
defer os.RemoveAll(tempDir)

server, file1Digest, file2Digest, hits := startFetchTestServer(t)
defer server.Close()

hook := &recordingFetchHook{
skipDigests: map[string]bool{file1Digest.String(): true},
}

b := &backend{}
url := strings.TrimPrefix(server.URL, "http://")
cfg := &config.Fetch{
Output: tempDir,
Patterns: []string{"*.txt"},
PlainHTTP: true,
Concurrency: 2,
Hooks: hook,
}

require.NoError(t, b.Fetch(context.Background(), url+"/test/model:latest", cfg))

// file1 must NOT have been downloaded; file2 must have been.
assert.Equal(t, int32(0), atomic.LoadInt32(hits[file1Digest.String()]),
"skipped layer should not be fetched from remote")
assert.Equal(t, int32(1), atomic.LoadInt32(hits[file2Digest.String()]),
"non-skipped layer should be fetched once")

// BeforePullLayer fires for both layers exactly once (no retries on success).
assert.Equal(t, int32(2), atomic.LoadInt32(&hook.beforeCount))

// AfterPullLayer must be invoked for both layers, with proper skipped flag.
hook.mu.Lock()
defer hook.mu.Unlock()
require.Len(t, hook.afterCalls, 2)

byDigest := map[string]afterFetchCall{}
for _, c := range hook.afterCalls {
byDigest[c.digest] = c
}
assert.True(t, byDigest[file1Digest.String()].skipped, "file1 should be marked skipped")
assert.NoError(t, byDigest[file1Digest.String()].err)
assert.False(t, byDigest[file2Digest.String()].skipped, "file2 should not be marked skipped")
assert.NoError(t, byDigest[file2Digest.String()].err)
}
17 changes: 14 additions & 3 deletions pkg/backend/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ import (
func (b *backend) Pull(ctx context.Context, target string, cfg *config.Pull) error {
logrus.Infof("pull: pulling artifact %s", target)

// Apply default hooks when caller leaves it unset to avoid nil deref.
if cfg.Hooks == nil {
defaults := config.NewPull()
cfg.Hooks = defaults.Hooks
}
Comment thread
chlins marked this conversation as resolved.

// pullByDragonfly is called if a Dragonfly endpoint is specified in the configuration.
if cfg.DragonflyEndpoint != "" {
logrus.Infof("pull: using dragonfly for %s", target)
Expand Down Expand Up @@ -118,13 +124,18 @@ func (b *backend) Pull(ctx context.Context, target string, cfg *config.Pull) err

return retry.Do(func() error {
logrus.Debugf("pull: processing layer %s", layer.Digest)
// call the before hook.
cfg.Hooks.BeforePullLayer(layer, manifest)
// call the before hook; allow caller to skip this layer.
if cfg.Hooks.BeforePullLayer(layer, manifest) {
logrus.Debugf("pull: layer %s skipped by hook", layer.Digest)
pb.Complete(layer.Digest.String(), fmt.Sprintf("%s %s", internalpb.NormalizePrompt("Skipped blob"), layer.Digest.String()))
cfg.Hooks.AfterPullLayer(layer, true, nil)
return nil
}
err := tracker.TrackTransfer(func() error {
return fn(layer)
})
// call the after hook.
cfg.Hooks.AfterPullLayer(layer, err)
cfg.Hooks.AfterPullLayer(layer, false, err)
if err != nil {
err = fmt.Errorf("pull: failed to process layer %s: %w", layer.Digest, err)
logrus.Error(err)
Expand Down
9 changes: 7 additions & 2 deletions pkg/backend/pull_by_d7y.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,14 @@ func buildBlobURL(ref Referencer, plainHTTP bool, digest string) string {
func processLayer(ctx context.Context, pb *internalpb.ProgressBar, client dfdaemon.DfdaemonDownloadClient, ref Referencer, manifest ocispec.Manifest, desc ocispec.Descriptor, authToken string, cfg *config.Pull) error {
err := retry.Do(func() error {
logrus.Debugf("pull: processing layer %s", desc.Digest)
cfg.Hooks.BeforePullLayer(desc, manifest) // Call before hook
if cfg.Hooks.BeforePullLayer(desc, manifest) {
logrus.Debugf("pull: layer %s skipped by hook", desc.Digest)
pb.Complete(desc.Digest.String(), fmt.Sprintf("%s %s", internalpb.NormalizePrompt("Skipped blob"), desc.Digest.String()))
cfg.Hooks.AfterPullLayer(desc, true, nil)
return nil
}
err := downloadAndExtractLayer(ctx, pb, client, ref, desc, authToken, cfg)
cfg.Hooks.AfterPullLayer(desc, err) // Call after hook
cfg.Hooks.AfterPullLayer(desc, false, err) // Call after hook
if err != nil {
err = fmt.Errorf("pull: failed to download and extract layer %s: %w", desc.Digest, err)
logrus.Error(err)
Expand Down
25 changes: 19 additions & 6 deletions pkg/config/pull.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,29 @@ func (p *Pull) Validate() error {
}

// PullHooks is the hook events during the pull operation.
//
// Note: every retry attempt re-invokes BeforePullLayer / AfterPullLayer.
type PullHooks interface {
Comment thread
chlins marked this conversation as resolved.
// BeforePullLayer will execute before pulling the layer described as desc, will carry the manifest as well.
BeforePullLayer(desc ocispec.Descriptor, manifest ocispec.Manifest)
// BeforePullLayer will execute before pulling the layer described as desc,
// will carry the manifest as well.
//
// If the hook returns skip=true, the backend will treat this layer as
// already satisfied and will NOT actually pull/extract it. The caller is
// responsible for ensuring the corresponding content already exists and
// matches the descriptor's digest. AfterPullLayer will still be invoked
// with skipped=true and a nil error.
BeforePullLayer(desc ocispec.Descriptor, manifest ocispec.Manifest) (skip bool)

// AfterPullLayer will execute after pulling the layer described as desc, the error will be nil if pulled successfully.
AfterPullLayer(desc ocispec.Descriptor, err error)
// AfterPullLayer will execute after pulling the layer described as desc.
// skipped indicates whether the layer was skipped by BeforePullLayer's
// decision. err will be nil if pulled (or skipped) successfully.
AfterPullLayer(desc ocispec.Descriptor, skipped bool, err error)
}

// emptyPullHook is the empty pull hook implementation with do nothing.
type emptyPullHook struct{}

func (emptyPullHook) BeforePullLayer(desc ocispec.Descriptor, manifest ocispec.Manifest) {}
func (emptyPullHook) AfterPullLayer(desc ocispec.Descriptor, err error) {}
func (emptyPullHook) BeforePullLayer(desc ocispec.Descriptor, manifest ocispec.Manifest) bool {
return false
}
func (emptyPullHook) AfterPullLayer(desc ocispec.Descriptor, skipped bool, err error) {}
Loading