Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions experimental/ssh/cmd/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ For serverless compute:
var connectionName string
var accelerator string
var proxyMode bool
var ide string
var serverMetadata string
var shutdownDelay time.Duration
var maxClients int
Expand All @@ -42,8 +43,9 @@ For serverless compute:
var liteswap string

cmd.Flags().StringVar(&clusterID, "cluster", "", "Databricks cluster ID (for dedicated clusters)")
cmd.Flags().StringVar(&connectionName, "name", "", "Connection name (for serverless compute)")
cmd.Flags().StringVar(&accelerator, "accelerator", "", "GPU accelerator type for serverless compute (GPU_1xA10 or GPU_8xH100)")
cmd.Flags().StringVar(&connectionName, "name", "", "Connection name")
cmd.Flags().StringVar(&accelerator, "accelerator", "", "GPU accelerator type (GPU_1xA10 or GPU_8xH100)")
cmd.Flags().StringVar(&ide, "ide", "", "Open remote IDE window (vscode or cursor)")
cmd.Flags().DurationVar(&shutdownDelay, "shutdown-delay", defaultShutdownDelay, "Delay before shutting down the server after the last client disconnects")
cmd.Flags().IntVar(&maxClients, "max-clients", defaultMaxClients, "Maximum number of SSH clients")
cmd.Flags().BoolVar(&autoStartCluster, "auto-start-cluster", true, "Automatically start the cluster if it is not running")
Expand Down Expand Up @@ -80,7 +82,7 @@ For serverless compute:
wsClient := cmdctx.WorkspaceClient(ctx)

if !proxyMode && clusterID == "" && connectionName == "" {
return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the serverless connection name")
return errors.New("please provide --cluster flag with the cluster ID, or --name flag with the connection name (for serverless compute)")
}

if accelerator != "" && connectionName == "" {
Expand All @@ -89,7 +91,7 @@ For serverless compute:

// Remove when we add support for serverless CPU
if connectionName != "" && accelerator == "" {
return errors.New("--name flag requires --accelerator to be set (e.g. for now we only support serverless GPU compute)")
return errors.New("--name flag requires --accelerator to be set (for now we only support serverless GPU compute)")
}

// TODO: validate connectionName if provided
Expand All @@ -100,6 +102,7 @@ For serverless compute:
ConnectionName: connectionName,
Accelerator: accelerator,
ProxyMode: proxyMode,
IDE: ide,
ServerMetadata: serverMetadata,
ShutdownDelay: shutdownDelay,
MaxClients: maxClients,
Expand Down
21 changes: 17 additions & 4 deletions experimental/ssh/cmd/setup.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package ssh

import (
"fmt"
"time"

"github.com/databricks/cli/cmd/root"
"github.com/databricks/cli/experimental/ssh/internal/client"
"github.com/databricks/cli/experimental/ssh/internal/setup"
"github.com/databricks/cli/libs/cmdctx"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -43,16 +45,27 @@ an SSH host configuration to your SSH config file.

cmd.RunE = func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
client := cmdctx.WorkspaceClient(ctx)
opts := setup.SetupOptions{
wsClient := cmdctx.WorkspaceClient(ctx)
setupOpts := setup.SetupOptions{
HostName: hostName,
ClusterID: clusterID,
AutoStartCluster: autoStartCluster,
SSHConfigPath: sshConfigPath,
ShutdownDelay: shutdownDelay,
Profile: client.Config.Profile,
Profile: wsClient.Config.Profile,
}
return setup.Setup(ctx, client, opts)
clientOpts := client.ClientOptions{
ClusterID: setupOpts.ClusterID,
AutoStartCluster: setupOpts.AutoStartCluster,
ShutdownDelay: setupOpts.ShutdownDelay,
Profile: setupOpts.Profile,
}
proxyCommand, err := clientOpts.ToProxyCommand()
if err != nil {
return fmt.Errorf("failed to generate ProxyCommand: %w", err)
}
setupOpts.ProxyCommand = proxyCommand
return setup.Setup(ctx, wsClient, setupOpts)
}

return cmd
Expand Down
94 changes: 92 additions & 2 deletions experimental/ssh/internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/databricks/cli/experimental/ssh/internal/keys"
"github.com/databricks/cli/experimental/ssh/internal/proxy"
"github.com/databricks/cli/experimental/ssh/internal/sshconfig"
sshWorkspace "github.com/databricks/cli/experimental/ssh/internal/workspace"
"github.com/databricks/cli/internal/build"
"github.com/databricks/cli/libs/cmdio"
Expand Down Expand Up @@ -55,6 +56,8 @@ type ClientOptions struct {
// to the cluster and proxy all traffic through stdin/stdout.
// In non-proxy mode, the CLI spawns an ssh client with the ProxyCommand config.
ProxyMode bool
// Open remote IDE window with a specific ssh config (empty, 'vscode', or 'cursor')
IDE string
// Expected format: "<user_name>,<port>,<cluster_id>".
// If present, the CLI won't attempt to start the server.
ServerMetadata string
Expand Down Expand Up @@ -168,8 +171,7 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt
}

// Only check cluster state for dedicated clusters
// TODO: we can remove liteswap check when we can start serverless GPU clusters via API.
if !opts.IsServerlessMode() && opts.Liteswap == "" {
if !opts.IsServerlessMode() {
err := checkClusterState(ctx, client, opts.ClusterID, opts.AutoStartCluster)
if err != nil {
return err
Expand Down Expand Up @@ -247,12 +249,100 @@ func Run(ctx context.Context, client *databricks.WorkspaceClient, opts ClientOpt

if opts.ProxyMode {
return runSSHProxy(ctx, client, serverPort, clusterID, opts)
} else if opts.IDE != "" {
return runIDE(ctx, client, userName, keyPath, serverPort, clusterID, opts)
} else {
cmdio.LogString(ctx, fmt.Sprintf("Additional SSH arguments: %v", opts.AdditionalArgs))
return spawnSSHClient(ctx, userName, keyPath, serverPort, clusterID, opts)
}
}

// runIDE opens a remote IDE window (VS Code or Cursor) over the SSH connection.
//
// It maps opts.IDE onto the IDE's command-line launcher ("code" or "cursor"),
// makes sure an SSH config entry named after the session identifier exists, and
// then runs the launcher with a "ssh-remote+<userName>@<connectionName>" URI
// pointing at the current Databricks user's /Workspace/Users/<user>/ folder.
//
// Parameters:
//   - client: workspace client, used to look up the current Databricks user.
//   - userName, keyPath, serverPort, clusterID: forwarded to ensureSSHConfigEntry;
//     userName is also embedded in the remote URI.
//   - opts: client options; opts.IDE must be "vscode" or "cursor".
//
// An error is returned when opts.IDE is not supported, the launcher cannot be
// found, the session identifier is empty, the current user or SSH config path
// cannot be resolved, the SSH config entry cannot be written, or the launcher
// itself exits with an error.
func runIDE(ctx context.Context, client *databricks.WorkspaceClient, userName, keyPath string, serverPort int, clusterID string, opts ClientOptions) error {
	// Validate the IDE value and pick its launcher in a single place so the two
	// can never drift apart.
	var ideCommand string
	switch opts.IDE {
	case "vscode":
		ideCommand = "code"
	case "cursor":
		ideCommand = "cursor"
	default:
		return fmt.Errorf("invalid IDE value: %s, expected 'vscode' or 'cursor'", opts.IDE)
	}

	// Fail fast when the launcher is missing, before the SSH config is touched.
	if _, err := exec.LookPath(ideCommand); err != nil {
		return fmt.Errorf("failed to locate the %q command for %s (is its command-line launcher installed and on PATH?): %w", ideCommand, opts.IDE, err)
	}

	// Get connection name
	connectionName := opts.SessionIdentifier()
	if connectionName == "" {
		return errors.New("connection name is required for IDE integration")
	}

	// Get Databricks user name for the workspace path
	currentUser, err := client.CurrentUser.Me(ctx)
	if err != nil {
		return fmt.Errorf("failed to get current user: %w", err)
	}
	databricksUserName := currentUser.UserName

	// Ensure SSH config entry exists; the IDE connects through the host alias
	// used in the remote URI below.
	configPath, err := sshconfig.GetMainConfigPath()
	if err != nil {
		return fmt.Errorf("failed to get SSH config path: %w", err)
	}
	if err := ensureSSHConfigEntry(ctx, configPath, connectionName, userName, keyPath, serverPort, clusterID, opts); err != nil {
		return fmt.Errorf("failed to ensure SSH config entry: %w", err)
	}

	// Construct the remote SSH URI
	// Format: ssh-remote+<server_user_name>@<connection_name> /Workspace/Users/<databricks_user_name>/
	remoteURI := fmt.Sprintf("ssh-remote+%s@%s", userName, connectionName)
	remotePath := fmt.Sprintf("/Workspace/Users/%s/", databricksUserName)

	cmdio.LogString(ctx, fmt.Sprintf("Launching %s with remote URI: %s and path: %s", opts.IDE, remoteURI, remotePath))

	// Launch the IDE, forwarding its output to the terminal.
	ideCmd := exec.CommandContext(ctx, ideCommand, "--remote", remoteURI, remotePath)
	ideCmd.Stdout = os.Stdout
	ideCmd.Stderr = os.Stderr

	if err := ideCmd.Run(); err != nil {
		return fmt.Errorf("failed to launch %s: %w", opts.IDE, err)
	}
	return nil
}

// ensureSSHConfigEntry creates or updates the SSH host entry named hostName.
//
// It first makes sure the Include directive is present in the main SSH config
// at configPath, then builds a ProxyCommand from opts with the server metadata
// (userName, serverPort, clusterID) filled in, and finally stores a Host block
// using userName, keyPath and that ProxyCommand.
//
// Errors from the sshconfig helpers are returned unchanged; a failure to build
// the ProxyCommand is wrapped with additional context.
func ensureSSHConfigEntry(ctx context.Context, configPath, hostName, userName, keyPath string, serverPort int, clusterID string, opts ClientOptions) error {
	// Layout of the generated Host block; placeholders are filled in order:
	// host alias, user, identity file (quoted), proxy command.
	const hostTemplate = `
Host %s
User %s
ConnectTimeout 360
StrictHostKeyChecking accept-new
IdentitiesOnly yes
IdentityFile %q
ProxyCommand %s
`

	if err := sshconfig.EnsureIncludeDirective(configPath); err != nil {
		return err
	}

	// Set the metadata on a copy so only the options used to build the
	// ProxyCommand carry it.
	proxyOpts := opts
	proxyOpts.ServerMetadata = FormatMetadata(userName, serverPort, clusterID)

	proxyCmd, err := proxyOpts.ToProxyCommand()
	if err != nil {
		return fmt.Errorf("failed to generate ProxyCommand: %w", err)
	}

	entry := fmt.Sprintf(hostTemplate, hostName, userName, keyPath, proxyCmd)
	if _, err := sshconfig.CreateOrUpdateHostConfig(ctx, hostName, entry, true); err != nil {
		return err
	}

	cmdio.LogString(ctx, fmt.Sprintf("Updated SSH config entry for '%s'", hostName))
	return nil
}

// getServerMetadata retrieves the server metadata from the workspace and validates it via Driver Proxy.
// sessionID is the unique identifier for the session (cluster ID for dedicated clusters, connection name for serverless).
// For dedicated clusters, clusterID should be the same as sessionID.
Expand Down
Loading
Loading