Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ func (cfg *RawConfig) ParameterKeyExcludeModelWeights() string {
return cfg.ServiceName + "/exclude-model-weights"
}

func (cfg *RawConfig) ParameterKeyExcludeFiles() string {
return cfg.ServiceName + "/exclude-file-patterns"
}

// /var/lib/dragonfly/model-csi/volumes
func (cfg *RawConfig) GetVolumesDir() string {
return filepath.Join(cfg.RootDir, "volumes")
Expand Down
4 changes: 2 additions & 2 deletions pkg/server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type mockPuller struct {
}

func (puller *mockPuller) Pull(
ctx context.Context, reference, targetDir string, excludeModelWeights bool,
ctx context.Context, reference, targetDir string, excludeModelWeights bool, excludeFilePatterns []string,
) error {
if err := os.MkdirAll(targetDir, 0755); err != nil {
return err
Expand Down Expand Up @@ -560,7 +560,7 @@ func TestServer(t *testing.T) {
cfg.Get().PullConfig.ProxyURL = ""
service.CacheScanInterval = 1 * time.Second

service.NewPuller = func(ctx context.Context, pullCfg *config.PullConfig, hook *status.Hook, diskQuotaChecker *service.DiskQuotaChecker) service.Puller {
service.NewPuller = func(ctx context.Context, pullCfg *config.PullConfig, hook *status.Hook, diskQuotaChecker *service.DiskQuotaChecker, excludeFilePatterns []string) service.Puller {
return &mockPuller{
pullCfg: pullCfg,
duration: time.Second * 2,
Expand Down
14 changes: 12 additions & 2 deletions pkg/service/controller_local.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ func (s *Service) localCreateVolume(ctx context.Context, req *csi.CreateVolumeRe
}
}

excludeFilePatternsParam := strings.TrimSpace(parameters[s.cfg.Get().ParameterKeyExcludeFiles()])
var excludeFilePatterns []string
if excludeFilePatternsParam != "" {
excludeFilePatterns = strings.Split(excludeFilePatternsParam, ",")
if len(excludeFilePatterns) == 0 {
return nil, isStaticVolume, status.Errorf(codes.InvalidArgument, "invalid parameter:%s: must be valid comma-separated pattern", s.cfg.Get().ParameterKeyExcludeFiles())
}

}

parentSpan := trace.SpanFromContext(ctx)
parentSpan.SetAttributes(attribute.String("volume_name", volumeName))
parentSpan.SetAttributes(attribute.String("reference", modelReference))
Expand All @@ -78,7 +88,7 @@ func (s *Service) localCreateVolume(ctx context.Context, req *csi.CreateVolumeRe
startedAt := time.Now()
ctx, span := tracing.Tracer.Start(ctx, "PullModel")
span.SetAttributes(attribute.String("model_dir", modelDir))
if err := s.worker.PullModel(ctx, isStaticVolume, volumeName, "", modelReference, modelDir, checkDiskQuota, excludeModelWeights); err != nil {
if err := s.worker.PullModel(ctx, isStaticVolume, volumeName, "", modelReference, modelDir, checkDiskQuota, excludeModelWeights, excludeFilePatterns); err != nil {
span.SetStatus(otelCodes.Error, "failed to pull model")
span.RecordError(err)
span.End()
Expand Down Expand Up @@ -111,7 +121,7 @@ func (s *Service) localCreateVolume(ctx context.Context, req *csi.CreateVolumeRe
startedAt := time.Now()
ctx, span := tracing.Tracer.Start(ctx, "PullModel")
span.SetAttributes(attribute.String("model_dir", modelDir))
if err := s.worker.PullModel(ctx, isStaticVolume, volumeName, mountID, modelReference, modelDir, checkDiskQuota, excludeModelWeights); err != nil {
if err := s.worker.PullModel(ctx, isStaticVolume, volumeName, mountID, modelReference, modelDir, checkDiskQuota, excludeModelWeights, excludeFilePatterns); err != nil {
span.SetStatus(otelCodes.Error, "failed to pull model")
span.RecordError(err)
span.End()
Expand Down
11 changes: 11 additions & 0 deletions pkg/service/dynamic_server_handler.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package service

import (
"encoding/json"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -86,6 +87,15 @@ func (h *DynamicServerHandler) CreateVolume(c echo.Context) error {
})
}

excludeFilesJSON := "[]"
if len(req.ExcludeFilePatterns) > 0 {
jsonBytes, err := json.Marshal(req.ExcludeFilePatterns)
if err != nil {
return handleError(c, fmt.Errorf("marshal exclude_file_patterns: %w", err))
}
excludeFilesJSON = string(jsonBytes)
}

_, err := h.svc.CreateVolume(c.Request().Context(), &csi.CreateVolumeRequest{
Name: volumeName,
Parameters: map[string]string{
Expand All @@ -94,6 +104,7 @@ func (h *DynamicServerHandler) CreateVolume(c echo.Context) error {
h.cfg.Get().ParameterKeyMountID(): req.MountID,
h.cfg.Get().ParameterKeyCheckDiskQuota(): strconv.FormatBool(req.CheckDiskQuota),
h.cfg.Get().ParameterKeyExcludeModelWeights(): strconv.FormatBool(req.ExcludeModelWeights),
h.cfg.Get().ParameterKeyExcludeFiles(): excludeFilesJSON,
},
})
if err != nil {
Expand Down
19 changes: 18 additions & 1 deletion pkg/service/node.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package service

import (
"encoding/json"
"path/filepath"
"strconv"
"strings"
Expand Down Expand Up @@ -104,8 +105,24 @@ func (s *Service) nodePublishVolume(
}
}

excludeFilePatternsParam := volumeAttributes[s.cfg.Get().ParameterKeyExcludeFiles()]
var excludeFilePatterns []string
if excludeFilePatternsParam != "" {
if err := json.Unmarshal([]byte(excludeFilePatternsParam), &excludeFilePatterns); err != nil {
Comment thread
aagumin marked this conversation as resolved.
return nil, isStaticVolume, status.Errorf(codes.InvalidArgument, "invalid parameter:%s: must be valid JSON array: %v", s.cfg.Get().ParameterKeyExcludeFiles(), err)
}
for _, p := range excludeFilePatterns {
if strings.HasPrefix(p, "/") && len(p) > 1 {
return nil, isStaticVolume, status.Errorf(codes.InvalidArgument, "invalid parameter:%s: absolute paths not allowed: %s", s.cfg.Get().ParameterKeyExcludeFiles(), p)
}
if strings.Contains(p, "..") {
return nil, isStaticVolume, status.Errorf(codes.InvalidArgument, "invalid parameter:%s: parent directory reference not allowed: %s", s.cfg.Get().ParameterKeyExcludeFiles(), p)
}
}
}

logger.WithContext(ctx).Infof("publishing static inline volume: %s", staticInlineModelReference)
resp, err := s.nodePublishVolumeStaticInlineVolume(ctx, volumeID, targetPath, staticInlineModelReference, excludeModelWeights)
resp, err := s.nodePublishVolumeStaticInlineVolume(ctx, volumeID, targetPath, staticInlineModelReference, excludeModelWeights, excludeFilePatterns)
return resp, isStaticVolume, err
}

Expand Down
4 changes: 2 additions & 2 deletions pkg/service/node_static_inline.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ import (
"google.golang.org/grpc/status"
)

func (s *Service) nodePublishVolumeStaticInlineVolume(ctx context.Context, volumeName, targetPath, reference string, excludeModelWeights bool) (*csi.NodePublishVolumeResponse, error) {
func (s *Service) nodePublishVolumeStaticInlineVolume(ctx context.Context, volumeName, targetPath, reference string, excludeModelWeights bool, excludeFilePatterns []string) (*csi.NodePublishVolumeResponse, error) {
modelDir := s.cfg.Get().GetModelDir(volumeName)

startedAt := time.Now()
if err := s.worker.PullModel(ctx, true, volumeName, "", reference, modelDir, false, excludeModelWeights); err != nil {
if err := s.worker.PullModel(ctx, true, volumeName, "", reference, modelDir, false, excludeModelWeights, excludeFilePatterns); err != nil {
return nil, status.Error(codes.Internal, errors.Wrap(err, "pull model").Error())
}
duration := time.Since(startedAt)
Expand Down
167 changes: 167 additions & 0 deletions pkg/service/patterns.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
package service

import (
"io"
"os"
"path/filepath"
"sort"
"strings"

gitignore "github.com/go-git/go-git/v5/plumbing/format/gitignore"
"github.com/modelpack/model-csi-driver/pkg/logger"
"github.com/pkg/errors"
)

// FilePatternMatcher wraps gitignore pattern matching functionality
type FilePatternMatcher struct {
matcher gitignore.Matcher
patterns []string
}

// NewFilePatternMatcher creates a new pattern matcher from a list of gitignore-compatible patterns
func NewFilePatternMatcher(patterns []string) (*FilePatternMatcher, error) {
// Create gitignore matcher from patterns
// Parse each string pattern into gitignore.Pattern
var gitPatterns []gitignore.Pattern
for _, p := range patterns {
gitPatterns = append(gitPatterns, gitignore.ParsePattern(p, nil))
}
matcher := gitignore.NewMatcher(gitPatterns)

return &FilePatternMatcher{
matcher: matcher,
patterns: patterns,
}, nil
}

// Match returns true if the given path matches any of the exclusion patterns
func (m *FilePatternMatcher) Match(path string) bool {
// gitignore matcher expects paths in forward-slash format
// and uses a slice of strings for path components
path = filepath.ToSlash(path)
pathParts := strings.Split(path, "/")
isDir := strings.HasSuffix(path, "/")
return m.matcher.Match(pathParts, isDir)
}

// Excludes returns true if any exclusion patterns are defined
func (m *FilePatternMatcher) Excludes() bool {
return len(m.patterns) > 0
}

// filterFilesByPatterns walks the target directory and removes files matching the exclusion patterns
// Returns a list of excluded file paths (relative to targetDir)
func filterFilesByPatterns(targetDir string, matcher *FilePatternMatcher) ([]string, error) {
excludedFiles := []string{}

// First pass: identify and remove matched files
err := filepath.Walk(targetDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}

// Skip the target directory itself
if path == targetDir {
return nil
}

// Get relative path for pattern matching
relPath, err := filepath.Rel(targetDir, path)
if err != nil {
return errors.Wrap(err, "get relative path")
}

// Check if file/directory matches exclusion pattern
if matcher.Match(relPath) {
if !info.IsDir() {
logger.Logger().Infof("Excluding file: %s", relPath)
excludedFiles = append(excludedFiles, relPath)

// Remove the file
if err := os.Remove(path); err != nil {
return errors.Wrapf(err, "remove excluded file: %s", relPath)
}
}
}

return nil
})

if err != nil {
return nil, errors.Wrap(err, "walk directory for pattern matching")
}

// Second pass: remove empty directories
removeEmptyDirectories(targetDir, matcher)

// Sort excluded files for consistent logging
sort.Strings(excludedFiles)

logger.Logger().Infof("Excluded %d file(s) matching patterns", len(excludedFiles))

return excludedFiles, nil
}

// removeEmptyDirectories removes empty directories that were created after file removal
func removeEmptyDirectories(targetDir string, matcher *FilePatternMatcher) {
dirsToRemove := []string{}

// First, find all empty directories
err := filepath.Walk(targetDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return nil // Continue on error
}

if info.IsDir() && path != targetDir {
isEmpty, err := isDirEmpty(path)
if err != nil {
logger.Logger().WithError(err).Warnf("Failed to check if directory is empty: %s", path)
return nil
}
if isEmpty {
dirsToRemove = append(dirsToRemove, path)
}
}

return nil
})

if err != nil {
logger.Logger().WithError(err).Warn("Failed to walk directories for cleanup")
return
}

// Remove empty directories in reverse order (deepest first)
for i := len(dirsToRemove) - 1; i >= 0; i-- {
dir := dirsToRemove[i]
if err := os.Remove(dir); err != nil {
logger.Logger().WithError(err).Warnf("Failed to remove empty directory: %s", dir)
} else {
relPath, _ := filepath.Rel(targetDir, dir)
logger.Logger().Infof("Removed empty directory: %s", relPath)
}
}
}

// isDirEmpty checks if a directory is empty
func isDirEmpty(dir string) (bool, error) {
f, err := os.Open(dir)
if err != nil {
return false, err
}
defer func(f *os.File) {
err = f.Close()
if err != nil {
return
}
}(f)

_, err = f.Readdirnames(1)
if err == nil {
return false, nil // Directory is not empty
}
if err == io.EOF {
return true, nil // Directory is empty
}
return false, err // Error reading directory
}
Loading
Loading