Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 232 additions & 29 deletions experimental/ssh/internal/client/client.go

Large diffs are not rendered by default.

86 changes: 80 additions & 6 deletions experimental/ssh/internal/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/databricks/cli/experimental/ssh/internal/client"
"github.com/databricks/cli/libs/telemetry/protos"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand All @@ -18,9 +19,9 @@ func TestValidate(t *testing.T) {
wantErr string
}{
{
name: "no cluster or connection name",
name: "no cluster or connection name or accelerator",
opts: client.ClientOptions{},
wantErr: "please provide --cluster flag with the cluster ID, or --name flag with the connection name (for serverless compute)",
wantErr: "please provide --cluster or --accelerator flag",
},
{
name: "proxy mode skips cluster/name check",
Expand All @@ -31,9 +32,13 @@ func TestValidate(t *testing.T) {
opts: client.ClientOptions{ClusterID: "abc-123"},
},
{
name: "accelerator without connection name",
name: "accelerator with cluster ID",
opts: client.ClientOptions{ClusterID: "abc-123", Accelerator: "GPU_1xA10"},
wantErr: "--accelerator flag can only be used with serverless compute (--name flag)",
wantErr: "--accelerator flag can only be used with serverless compute, not with --cluster",
},
{
name: "accelerator only (auto-generate session name)",
opts: client.ClientOptions{Accelerator: "GPU_1xA10"},
},
{
name: "connection name without accelerator",
Expand All @@ -55,8 +60,9 @@ func TestValidate(t *testing.T) {
opts: client.ClientOptions{ConnectionName: "my-conn_1", Accelerator: "GPU_1xA10"},
},
{
name: "both cluster ID and connection name",
opts: client.ClientOptions{ClusterID: "abc-123", ConnectionName: "my-conn", Accelerator: "GPU_1xA10"},
name: "both cluster ID and connection name (no accelerator)",
opts: client.ClientOptions{ClusterID: "abc-123", ConnectionName: "my-conn"},
wantErr: "--name flag requires --accelerator to be set (for now we only support serverless GPU compute)",
},
{
name: "proxy mode with invalid connection name",
Expand Down Expand Up @@ -164,3 +170,71 @@ func TestToProxyCommand(t *testing.T) {
})
}
}

// TestBuildTelemetryEvent verifies that client options map to the expected
// telemetry event fields across compute types and client modes.
func TestBuildTelemetryEvent(t *testing.T) {
	cases := []struct {
		name string
		opts client.ClientOptions
		want *protos.SshTunnelEvent
	}{
		{
			name: "dedicated cluster with SSH client",
			opts: client.ClientOptions{
				ClusterID:        "abc-123",
				AutoStartCluster: true,
			},
			want: &protos.SshTunnelEvent{
				ComputeType:      protos.SshTunnelComputeTypeDedicated,
				ClientMode:       protos.SshTunnelClientModeSSH,
				AutoStartCluster: true,
			},
		},
		{
			name: "serverless with IDE",
			opts: client.ClientOptions{
				ConnectionName: "my-conn",
				Accelerator:    "GPU_1xA10",
				IDE:            "vscode",
			},
			want: &protos.SshTunnelEvent{
				ComputeType:     protos.SshTunnelComputeTypeServerless,
				ClientMode:      protos.SshTunnelClientModeIDE,
				AcceleratorType: "GPU_1xA10",
				IdeType:         "vscode",
			},
		},
		{
			name: "proxy mode with metadata (reconnect)",
			opts: client.ClientOptions{
				ClusterID:      "abc-123",
				ProxyMode:      true,
				ServerMetadata: "user,2222,abc-123",
			},
			want: &protos.SshTunnelEvent{
				ComputeType: protos.SshTunnelComputeTypeDedicated,
				ClientMode:  protos.SshTunnelClientModeProxy,
				IsReconnect: true,
			},
		},
		{
			name: "serverless proxy mode",
			opts: client.ClientOptions{
				ConnectionName: "my-conn",
				Accelerator:    "GPU_8xH100",
				ProxyMode:      true,
			},
			want: &protos.SshTunnelEvent{
				ComputeType:     protos.SshTunnelComputeTypeServerless,
				ClientMode:      protos.SshTunnelClientModeProxy,
				AcceleratorType: "GPU_8xH100",
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			assert.Equal(t, tc.want, client.BuildTelemetryEvent(tc.opts))
		})
	}
}
12 changes: 6 additions & 6 deletions experimental/ssh/internal/client/releases.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ import (
"strings"

"github.com/databricks/cli/experimental/ssh/internal/workspace"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/filer"
"github.com/databricks/cli/libs/log"
"github.com/databricks/databricks-sdk-go"
)

Expand Down Expand Up @@ -48,7 +48,7 @@ func uploadReleases(ctx context.Context, workspaceFiler filer.Filer, getRelease

_, err := workspaceFiler.Stat(ctx, remoteBinaryPath)
if err == nil {
cmdio.LogString(ctx, fmt.Sprintf("File %s already exists in the workspace, skipping upload", remoteBinaryPath))
log.Infof(ctx, "File %s already exists in the workspace, skipping upload", remoteBinaryPath)
continue
} else if !errors.Is(err, fs.ErrNotExist) {
return fmt.Errorf("failed to check if file %s exists in workspace: %w", remoteBinaryPath, err)
Expand All @@ -60,14 +60,14 @@ func uploadReleases(ctx context.Context, workspaceFiler filer.Filer, getRelease
}
defer releaseReader.Close()

cmdio.LogString(ctx, fmt.Sprintf("Uploading %s to the workspace", fileName))
log.Infof(ctx, "Uploading %s to the workspace", fileName)
// workspace-files/import-file API will automatically unzip the payload,
// producing the filerRoot/remoteSubFolder/*archive-contents* structure, with 'databricks' binary inside.
err = workspaceFiler.Write(ctx, remoteArchivePath, releaseReader, filer.OverwriteIfExists, filer.CreateParentDirectories)
if err != nil {
return fmt.Errorf("failed to upload file %s to workspace: %w", remoteArchivePath, err)
}
cmdio.LogString(ctx, fmt.Sprintf("Successfully uploaded %s to workspace", remoteBinaryPath))
log.Infof(ctx, "Successfully uploaded %s to workspace", remoteBinaryPath)
}

return nil
Expand All @@ -81,7 +81,7 @@ func getReleaseName(architecture, version string) string {
}

func getLocalRelease(ctx context.Context, architecture, version, releasesDir string) (io.ReadCloser, error) {
cmdio.LogString(ctx, "Looking for CLI releases in directory: "+releasesDir)
log.Infof(ctx, "Looking for CLI releases in directory: %s", releasesDir)
releaseName := getReleaseName(architecture, version)
releasePath := filepath.Join(releasesDir, releaseName)
file, err := os.Open(releasePath)
Expand All @@ -95,7 +95,7 @@ func getGithubRelease(ctx context.Context, architecture, version, releasesDir st
// TODO: download and check databricks_cli_<version>_SHA256SUMS
fileName := getReleaseName(architecture, version)
downloadURL := fmt.Sprintf("https://github.com/databricks/cli/releases/download/v%s/%s", version, fileName)
cmdio.LogString(ctx, fmt.Sprintf("Downloading %s from %s", fileName, downloadURL))
log.Infof(ctx, "Downloading %s from %s", fileName, downloadURL)

resp, err := http.Get(downloadURL)
if err != nil {
Expand Down
6 changes: 3 additions & 3 deletions experimental/ssh/internal/proxy/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@ import (
"io"
"time"

"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/log"
"golang.org/x/sync/errgroup"
)

func RunClientProxy(ctx context.Context, src io.ReadCloser, dst io.Writer, requestHandoverTick func() <-chan time.Time, createConn createWebsocketConnectionFunc) error {
proxy := newProxyConnection(createConn)
cmdio.LogString(ctx, "Establishing SSH proxy connection...")
log.Infof(ctx, "Establishing SSH proxy connection...")
g, gCtx := errgroup.WithContext(ctx)
if err := proxy.connect(gCtx); err != nil {
return fmt.Errorf("failed to connect to proxy: %w", err)
}
defer proxy.close()
cmdio.LogString(ctx, "SSH proxy connection established")
log.Infof(ctx, "SSH proxy connection established")

g.Go(func() error {
for {
Expand Down
28 changes: 28 additions & 0 deletions experimental/ssh/internal/sessions/namegen.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package sessions

import (
"crypto/rand"
"encoding/hex"
"strings"
"time"
)

// acceleratorPrefixes maps known accelerator types to short human-readable prefixes.
var acceleratorPrefixes = map[string]string{
	"GPU_1xA10":  "gpu-a10",
	"GPU_8xH100": "gpu-h100",
}

// GenerateSessionName creates a human-readable session name from the accelerator type.
// Format: databricks-<prefix>-<yyyymmdd>-<random_hex>,
// e.g. "databricks-gpu-a10-20250102-f3a2b1".
func GenerateSessionName(accelerator string) string {
	prefix, ok := acceleratorPrefixes[accelerator]
	if !ok {
		// Unknown accelerator: fall back to a sanitized form of the raw type.
		prefix = strings.ToLower(strings.ReplaceAll(accelerator, "_", "-"))
	}

	date := time.Now().Format("20060102")
	// 3 random bytes -> 6 hex characters; enough entropy to avoid collisions
	// among a single user's sessions.
	b := make([]byte, 3)
	_, _ = rand.Read(b) // crypto/rand.Read does not fail in practice
	return "databricks-" + prefix + "-" + date + "-" + hex.EncodeToString(b)
}
147 changes: 147 additions & 0 deletions experimental/ssh/internal/sessions/sessions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package sessions

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io/fs"
	"os"
	"path/filepath"
	"time"

	"github.com/databricks/cli/libs/env"
)

const (
	// stateFileName is the name of the session state file, stored under the
	// user's ~/.databricks directory (see getStateFilePath).
	stateFileName = "ssh-tunnel-sessions.json"

	// Sessions older than this are considered expired and cleaned up automatically.
	sessionMaxAge = 24 * time.Hour
)

// Session represents a tracked SSH tunnel session.
type Session struct {
	// Name uniquely identifies the session; Add/Remove use it as the dedup key.
	Name string `json:"name"`
	// Accelerator is the accelerator type, e.g. "GPU_1xA10"; FindMatching matches on it.
	Accelerator string `json:"accelerator"`
	// WorkspaceHost is the workspace the session belongs to; FindMatching matches on it.
	WorkspaceHost string `json:"workspace_host"`
	// CreatedAt is used to expire sessions older than sessionMaxAge.
	CreatedAt time.Time `json:"created_at"`
	// ClusterID is optional; presumably set only for dedicated-cluster
	// sessions — confirm against callers.
	ClusterID string `json:"cluster_id,omitempty"`
}

// SessionStore holds all tracked sessions. It is the on-disk JSON shape of
// the state file.
type SessionStore struct {
	Sessions []Session `json:"sessions"`
}

// getStateFilePath returns the absolute path of the session state file,
// located in the user's ~/.databricks directory.
func getStateFilePath(ctx context.Context) (string, error) {
	home, err := env.UserHomeDir(ctx)
	if err != nil {
		return "", fmt.Errorf("failed to get home directory: %w", err)
	}
	path := filepath.Join(home, ".databricks", stateFileName)
	return path, nil
}

// Load reads the session store from disk. Returns an empty store if the file does not exist.
func Load(ctx context.Context) (*SessionStore, error) {
	path, err := getStateFilePath(ctx)
	if err != nil {
		return nil, err
	}

	data, err := os.ReadFile(path)
	// First run: a missing state file is not an error, just an empty store.
	// errors.Is is the modern form of os.IsNotExist and matches the style
	// used elsewhere in this package's sibling files.
	if errors.Is(err, fs.ErrNotExist) {
		return &SessionStore{}, nil
	}
	if err != nil {
		return nil, fmt.Errorf("failed to read session state file: %w", err)
	}

	var store SessionStore
	if err := json.Unmarshal(data, &store); err != nil {
		return nil, fmt.Errorf("failed to parse session state file: %w", err)
	}
	return &store, nil
}

// Save writes the session store to disk atomically.
func Save(ctx context.Context, store *SessionStore) error {
	path, err := getStateFilePath(ctx)
	if err != nil {
		return err
	}

	// 0o700: state may reference private workspace details; keep it user-only.
	if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
		return fmt.Errorf("failed to create state directory: %w", err)
	}

	data, err := json.MarshalIndent(store, "", " ")
	if err != nil {
		return fmt.Errorf("failed to marshal session state: %w", err)
	}

	// Atomic write: write to a temp file in the same directory, then rename,
	// so readers never observe a partially-written state file.
	tmpPath := path + ".tmp"
	if err := os.WriteFile(tmpPath, data, 0o600); err != nil {
		return fmt.Errorf("failed to write session state file: %w", err)
	}
	if err := os.Rename(tmpPath, path); err != nil {
		// Best-effort cleanup so a stale temp file does not linger on failure.
		_ = os.Remove(tmpPath)
		return fmt.Errorf("failed to rename session state file: %w", err)
	}
	return nil
}

// Add persists a new session to the store, replacing any existing session with the same name.
func Add(ctx context.Context, s Session) error {
	store, err := Load(ctx)
	if err != nil {
		return err
	}

	// Overwrite in place when a session with this name is already tracked;
	// otherwise append it as a new entry.
	replaced := false
	for i := range store.Sessions {
		if store.Sessions[i].Name == s.Name {
			store.Sessions[i] = s
			replaced = true
			break
		}
	}
	if !replaced {
		store.Sessions = append(store.Sessions, s)
	}

	return Save(ctx, store)
}

// Remove deletes a session by name.
func Remove(ctx context.Context, name string) error {
	store, err := Load(ctx)
	if err != nil {
		return err
	}

	// In-place filter reusing the backing array: keep every session whose
	// name differs from the one being removed.
	kept := store.Sessions[:0]
	for _, session := range store.Sessions {
		if session.Name == name {
			continue
		}
		kept = append(kept, session)
	}
	store.Sessions = kept
	return Save(ctx, store)
}

// FindMatching returns non-expired sessions that match the given workspace host and accelerator.
func FindMatching(ctx context.Context, workspaceHost, accelerator string) ([]Session, error) {
	store, err := Load(ctx)
	if err != nil {
		return nil, err
	}

	// Sessions created at or before this instant count as expired.
	cutoff := time.Now().Add(-sessionMaxAge)
	var matches []Session
	for _, session := range store.Sessions {
		if session.WorkspaceHost != workspaceHost || session.Accelerator != accelerator {
			continue
		}
		if !session.CreatedAt.After(cutoff) {
			continue
		}
		matches = append(matches, session)
	}
	return matches, nil
}
Loading
Loading