Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@ bin/*
# Compiled test tools
/weights-gen

# Test directory for weights testing
# Test directories for weights testing
/test-weights/
/test-weights-example/
weights.lock

# Auto-:d version files from setuptools-scm
Expand Down
9 changes: 9 additions & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,15 @@ The main commands for working on the CLI are:
- `mise run test:go` - Runs all Go unit tests
- `go test ./pkg/...` - Runs tests directly with `go test`

### Building and running Go binaries

**Never run `go build` and leave a binary in the repo.** Stray binaries bloat the repo and get accidentally committed. Follow these rules:

- **To test execution**, use `go run ./cmd/<name>` — no binary is produced.
- **To verify compilation**, use `go build ./cmd/<name>` (without `-o`) — this still writes a binary to the working directory, so prefer `go vet ./cmd/<name>` for a compile check that produces no artifact.
- **If you must produce a binary** (e.g. for integration tests), write it to a temp directory and clean up: `go build -o "$(mktemp -d)/binary" ./cmd/<name>`.
- **For installable builds**, use `mise run build:cog` or `make install` — these have proper output paths.

## Working on the Python SDK
The Python SDK is developed in the `python/cog/` directory. It uses `uv` for virtual environments and `tox` for testing across multiple Python versions.

Expand Down
44 changes: 44 additions & 0 deletions cmd/cog-kong/build.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package main

import (
"context"

"github.com/replicate/cog/pkg/config"
"github.com/replicate/cog/pkg/docker/command"
"github.com/replicate/cog/pkg/model"
"github.com/replicate/cog/pkg/registry"
"github.com/replicate/cog/pkg/util/console"
)

// BuildCmd implements the "cog build" command.
type BuildCmd struct {
BuildFlags `embed:""`

Tag string `name:"tag" short:"t" help:"A name for the built image in the form 'repository:tag'."`
}

// Validate is called by Kong after parsing, before Run. It replaces Cobra's PreRunE.
func (cmd *BuildCmd) Validate() error {
return cmd.ValidateMutualExclusivity()
}

// Run executes the build command.
func (cmd *BuildCmd) Run(ctx context.Context, dockerClient command.Command, regClient registry.Client, src *model.Source) error {
imageName := src.Config.Image
if cmd.Tag != "" {
imageName = cmd.Tag
}
if imageName == "" {
imageName = config.DockerImageName(src.ProjectDir)
}

resolver := model.NewResolver(dockerClient, regClient)
m, err := resolver.Build(ctx, src, cmd.BuildOptions(imageName, nil))
if err != nil {
return err
}

console.Infof("\nImage built as %s", m.ImageRef())

return nil
}
38 changes: 38 additions & 0 deletions cmd/cog-kong/cli.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package main

import (
"context"

"github.com/alecthomas/kong"

"github.com/replicate/cog/pkg/global"
"github.com/replicate/cog/pkg/update"
"github.com/replicate/cog/pkg/util/console"
)

// Globals holds flags available to every command.
// The AfterApply hook replaces Cobra's PersistentPreRun.
type Globals struct {
Debug bool `name:"debug" short:"d" env:"COG_DEBUG" help:"Show debugging output."`
Registry string `name:"registry" default:"${registry_default}" env:"COG_REGISTRY_HOST" hidden:"" help:"Registry host."`
Profile bool `name:"profile" hidden:"" help:"Enable profiling."`
Version kong.VersionFlag `name:"version" short:"v" help:"Show version of Cog."`
}

// AfterApply runs after flag parsing, before the command's Run.
// This is the Kong equivalent of Cobra's PersistentPreRun.
func (g *Globals) AfterApply(ctx context.Context) error {
if g.Debug {
global.Debug = true
console.SetLevel(console.DebugLevel)
}
if g.Profile {
global.ProfilingEnabled = true
}
global.ReplicateRegistryHost = g.Registry

if err := update.DisplayAndCheckForRelease(ctx); err != nil {
console.Debugf("%s", err)
}
return nil
}
28 changes: 28 additions & 0 deletions cmd/cog-kong/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package main

import (
"context"

"github.com/replicate/cog/pkg/docker"
"github.com/replicate/cog/pkg/docker/command"
"github.com/replicate/cog/pkg/provider"
"github.com/replicate/cog/pkg/provider/setup"
"github.com/replicate/cog/pkg/registry"
)

// provideDockerClient creates a Docker client, binding to the command.Command interface.
// Registered as a singleton provider so all commands share one connection.
func provideDockerClient(ctx context.Context) (command.Command, error) {
return docker.NewClient(ctx)
}

// provideRegistryClient creates a registry client, binding to the registry.Client interface.
func provideRegistryClient() registry.Client {
return registry.NewRegistryClient()
}

// provideProviderRegistry creates a provider registry with all built-in providers.
// This replaces the setup.Init() global side-effect pattern used in the cobra CLI.
func provideProviderRegistry() *provider.Registry {
return setup.NewRegistry()
}
100 changes: 100 additions & 0 deletions cmd/cog-kong/flags.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package main

import (
"fmt"
"os"
"strings"

"github.com/replicate/cog/pkg/config"
"github.com/replicate/cog/pkg/model"
)

// ConfigFlag is an embeddable flag group for specifying the cog.yaml path.
// Any command that embeds ConfigFlag (directly or via BuildFlags) automatically
// gets a ProvideModelSource method discovered by Kong's DI system.
type ConfigFlag struct {
File string `name:"file" short:"f" default:"cog.yaml" help:"The name of the config file."`
}

// ProvideModelSource is discovered by Kong's DI system (Provide* convention).
// It loads the model source from the config file path specified by --file.
func (c *ConfigFlag) ProvideModelSource() (*model.Source, error) {
return model.NewSource(c.File)
}

// BuildFlags groups all flags shared across commands that build images.
// Embed this in any command struct that calls resolver.Build().
type BuildFlags struct {
ConfigFlag `embed:""`

NoCache bool `name:"no-cache" help:"Do not use cache when building the image."`
SeparateWeights bool `name:"separate-weights" help:"Separate model weights from code in image layers."`
Secrets []string `name:"secret" help:"Secrets to pass to the build environment in the form 'id=foo,src=/path/to/file'."`
Progress string `name:"progress" default:"${progress_default}" enum:"auto,plain,tty,quiet" help:"Set type of build progress output: ${enum}."`
UseCudaBaseImage string `name:"use-cuda-base-image" default:"auto" enum:"auto,true,false" help:"Use Nvidia CUDA base image, 'true' (default) or 'false' (use python base image)."`
UseCogBaseImage *bool `name:"use-cog-base-image" help:"Use pre-built Cog base image for faster cold boots."`
OpenAPISchema string `name:"openapi-schema" type:"existingfile" help:"Load OpenAPI schema from a file."`

// Hidden flags
Dockerfile string `name:"dockerfile" hidden:"" type:"existingfile" help:"Path to a Dockerfile. If set, cog will use this Dockerfile instead of generating one from cog.yaml."`
Timestamp int64 `name:"timestamp" hidden:"" default:"-1" help:"Number of seconds since Epoch to use for the build timestamp."`
Strip bool `name:"strip" hidden:"" help:"Whether to strip shared libraries for faster inference times."`
Precompile bool `name:"precompile" hidden:"" help:"Whether to precompile python files for faster load times."`
}

// AfterApply syncs parsed flag values to package-level globals that the build
// pipeline reads. This runs after Kong parses flags but before Run().
func (b *BuildFlags) AfterApply() error {
config.BuildSourceEpochTimestamp = b.Timestamp
return nil
}

// BuildOptions constructs a model.BuildOptions from the current flag values.
// The imageName and annotations parameters vary by caller (build vs push).
func (b *BuildFlags) BuildOptions(imageName string, annotations map[string]string) model.BuildOptions {
return model.BuildOptions{
ImageName: imageName,
Secrets: b.Secrets,
NoCache: b.NoCache,
SeparateWeights: b.SeparateWeights,
UseCudaBaseImage: b.UseCudaBaseImage,
ProgressOutput: b.Progress,
SchemaFile: b.OpenAPISchema,
DockerfileFile: b.Dockerfile,
UseCogBaseImage: b.UseCogBaseImage,
Strip: b.Strip,
Precompile: b.Precompile,
Annotations: annotations,
OCIIndex: model.OCIIndexEnabled(),
}
}

// ValidateMutualExclusivity ensures that at most one of --use-cog-base-image,
// --use-cuda-base-image, and --dockerfile is explicitly set.
func (b *BuildFlags) ValidateMutualExclusivity() error {
var flagsSet []string
if b.UseCogBaseImage != nil {
flagsSet = append(flagsSet, "--use-cog-base-image")
}
if b.UseCudaBaseImage != "auto" {
flagsSet = append(flagsSet, "--use-cuda-base-image")
}
if b.Dockerfile != "" {
flagsSet = append(flagsSet, "--dockerfile")
}
if len(flagsSet) > 1 {
return fmt.Errorf("The flags %s are mutually exclusive: you can only set one of them", strings.Join(flagsSet, " and "))
}
return nil
}

// progressDefault returns the default progress output based on environment.
func progressDefault() string {
if v := os.Getenv("BUILDKIT_PROGRESS"); v != "" {
return v
}
if os.Getenv("TERM") == "dumb" {
return "plain"
}
return "auto"
}
108 changes: 108 additions & 0 deletions cmd/cog-kong/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package main

import (
"context"
"errors"
"fmt"
"os"
"os/signal"
"strings"
"syscall"

"github.com/alecthomas/kong"

"github.com/replicate/cog/pkg/global"
"github.com/replicate/cog/pkg/util/console"
)

// Build-time variables. Initialized from global defaults; overridden by -ldflags at build time.
var (
version = global.Version
commit = global.Commit
buildTime = global.BuildTime
)

// CLI is the root command struct. Kong parses into this.
type CLI struct {
Globals

Build BuildCmd `cmd:"" help:"Build an image from cog.yaml."`
Push PushCmd `cmd:"" help:"Build and push model in current directory to a Docker registry."`
}

func main() {
ctx, cancel := newCancellationContext()

var cli CLI

initOpts := []kong.Option{
// CLI metadata and variable interpolation for struct tags
kong.Name("cog"),
kong.Description("Containers for machine learning."),
kong.Vars{
"version": fmt.Sprintf("cog version %s (built %s)", version, buildTime),
"commit": commit,
"progress_default": progressDefault(),
"registry_default": global.DefaultReplicateRegistryHost,
},
kong.UsageOnError(),

// bindings for lazily injecting dependencies into Run() methods
kong.BindTo(ctx, (*context.Context)(nil)),
kong.BindSingletonProvider(provideDockerClient),
kong.BindToProvider(provideRegistryClient),
kong.BindSingletonProvider(provideProviderRegistry),
}

parser, err := kong.New(&cli, initOpts...)
if err != nil {
// Fatal error creating the parser — this is a bug, so panic to get a stack trace.
panic(err)
}

kctx, err := parser.Parse(os.Args[1:])

// Unable to parse input to a valid command
if err != nil {
// If the command isn't runnable (i.e. `cog`) just print help and exit 0 (matches Cobra behavior).
var parseErr *kong.ParseError
// Exit code 80 is kong's internal code for "no runnable command selected" (e.g. bare `cog` with no subcommand).
if errors.As(err, &parseErr) && parseErr.ExitCode() == 80 && strings.HasPrefix(parseErr.Error(), "expected") {
_ = parseErr.Context.PrintUsage(false)
return
}

// otherwise it's a real parse error (e.g. unexpected command or flag), so print the error and exit non-zero.
parser.FatalIfErrorf(err)
}

err = kctx.Run()
cancel()
// command returned an error. Print and exit non-zero.
if err != nil {
parser.FatalIfErrorf(err)
}
}

func newCancellationContext() (context.Context, context.CancelFunc) {
// First signal cancels the context, giving commands a chance to clean up.
// Second signal force-exits immediately.
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)

go func() {
// Block until the first signal cancels the context.
<-ctx.Done()

// Now register for the second signal after the first one has been received.
sig := make(chan os.Signal, 1)
signal.Notify(sig, os.Interrupt, syscall.SIGTERM)

console.Debugf("Shutting down. Signal again to force quit.")

<-sig
console.Warnf("Forced exit")
os.Exit(1)
}()

return ctx, cancel
}
Loading