Skip to content
Open
58 changes: 45 additions & 13 deletions pkg/cmd/kitimport/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,39 @@ import (
)

const (
shortDesc = `Import a model from HuggingFace`
longDesc = `Download a repository from HuggingFace and package it as a ModelKit.
shortDesc = `Import a model from HuggingFace or MLFlow`
longDesc = `Download a repository or MLFlow run and package it as a ModelKit.

The repository can be specified either by name (e.g. myorg/myrepo) or with a
full URL (https://huggingface.co/myorg/myrepo). The repository will be
downloaded to a temporary directory and packaged using a generated Kitfile.

MLFlow runs can be imported using the mlflow:// URI scheme:

mlflow://[tracking-host/]experiments/{exp_id}/runs/{run_id}
mlflow://[tracking-host/]runs/{run_id}

The tracking server is determined by (in order of precedence):
1. The host embedded in the mlflow:// URI
2. The MLFLOW_TRACKING_URI environment variable
3. The default http://localhost:5000

Authentication tokens can be set with the --token flag or via the
MLFLOW_TRACKING_TOKEN environment variable. Only FINISHED runs can be imported.

In interactive settings, this command will read the EDITOR environment variable
to determine which editor should be used for editing the Kitfile.

This command supports multiple ways of downloading files from the remote
repository. The tool used can be specified using the --tool flag with one of the
options below:

--tool=hf : Download files using the Huggingface API. Requires REPOSITORY to
be a Huggingface repository. This is the default for Huggingface
repositories
--tool=git : Download files using Git and Git LFS. Works for any Git
repository but requires that Git and Git LFS are installed.
--tool=hf : Download files using the Huggingface API. Requires REPOSITORY
to be a Huggingface repository. This is the default for
Huggingface repositories.
--tool=git : Download files using Git and Git LFS. Works for any Git
repository but requires Git and Git LFS to be installed.
--tool=mlflow : Import artifacts from an MLFlow run via the mlflow:// URI.

By default, Kit will automatically select the tool based on the provided
REPOSITORY.`
Expand All @@ -67,7 +81,16 @@ kit import myorg/myrepo --ref v1.0.0
kit import myorg/myrepo --tag myrepository:mytag

# Download repository and pack it using an existing Kitfile
kit import myorg/myrepo --file ./path/to/Kitfile`
kit import myorg/myrepo --file ./path/to/Kitfile

# Import an MLFlow run (tracking server from MLFLOW_TRACKING_URI env var)
kit import mlflow://experiments/42/runs/abc123 -t mymodel:v1

# Import an MLFlow run with explicit tracking server host
kit import mlflow://mlflow.company.com/experiments/42/runs/abc123 -t mymodel:v1

# Import using run ID only (short form)
kit import mlflow://runs/abc123 -t mymodel:v1`
)

type importOptions struct {
Expand Down Expand Up @@ -98,8 +121,8 @@ func ImportCommand() *cobra.Command {
cmd.Flags().StringVar(&opts.token, "token", "", "Token to use for authenticating with repository")
cmd.Flags().StringVarP(&opts.tag, "tag", "t", "", "Tag for the ModelKit (default is '[repository]:latest')")
cmd.Flags().StringVarP(&opts.kitfilePath, "file", "f", "", "Path to Kitfile to use for packing (use '-' to read from standard input)")
cmd.Flags().StringVar(&opts.downloadTool, "tool", "", "Tool to use for downloading files: options are 'git' and 'hf' (default: detect based on repository)")
cmd.Flags().IntVar(&opts.concurrency, "concurrency", 5, "Maximum number of simultaneous downloads (for huggingface)")
cmd.Flags().StringVar(&opts.downloadTool, "tool", "", "Tool to use for downloading files: options are 'git', 'hf', and 'mlflow' (default: detect based on repository)")
cmd.Flags().IntVar(&opts.concurrency, "concurrency", 5, "Maximum number of simultaneous downloads (for huggingface and mlflow)")
cmd.Flags().SortFlags = false
return cmd
}
Expand Down Expand Up @@ -134,13 +157,15 @@ func (opts *importOptions) complete(ctx context.Context, args []string) error {
var tagRepo string
if repo, _, err := hf.ParseHuggingFaceRepo(opts.repo); err == nil {
tagRepo = repo
} else if strings.HasPrefix(opts.repo, "mlflow://") {
tagRepo = extractTagFromMLFlowURI(opts.repo)
} else {
repo, err := extractRepoFromURL(opts.repo)
var err error
tagRepo, err = extractRepoFromURL(opts.repo)
if err != nil {
output.Errorf("Could not generate tag from URL: %s", err)
return fmt.Errorf("use flag --tag to set a tag for ModelKit")
}
tagRepo = repo
}
tagRepo = strings.ToLower(tagRepo)
opts.tag = fmt.Sprintf("%s:latest", tagRepo)
Expand All @@ -157,7 +182,7 @@ func (opts *importOptions) complete(ctx context.Context, args []string) error {
}
opts.modelKitRef = ref

validTools := []string{"git", "hf"}
validTools := []string{"git", "hf", "mlflow"}
if opts.downloadTool != "" && !slices.Contains(validTools, opts.downloadTool) {
return fmt.Errorf("invalid value for --tool flag. Valid options are: %s", strings.Join(validTools, ", "))
}
Expand All @@ -174,7 +199,14 @@ func getImporter(opts *importOptions) (func(context.Context, *importOptions) err
return importUsingHF, nil
case "git":
return importUsingGit, nil
case "mlflow":
return importUsingMLFlow, nil
default:
// Auto-detect based on URI scheme
if strings.HasPrefix(opts.repo, "mlflow://") {
return importUsingMLFlow, nil
}

repoUrl, err := url.Parse(opts.repo)
if err != nil {
return nil, fmt.Errorf("failed to parse %s: %w", opts.repo, err)
Expand Down
Loading
Loading