Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions cmd/kubectl_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,19 @@ Example:
$ rancher token delete all
`

const (
deviceAuthFlow = "devicecode"
authCodeFlow = "authcode"
)

type LoginInput struct {
server string
userID string
clusterID string
authProvider string
caCerts string
skipVerify bool
authFlow string // devicecode or authcode.
}

const (
Expand Down Expand Up @@ -120,6 +126,10 @@ func CredentialCommand() cli.Command {
Name: "auth-provider",
Usage: "Name of Auth Provider to use for authentication",
},
cli.StringFlag{
Name: "auth-flow",
Usage: "Auth flow to use for OAuth providers: 'devicecode' (default) or 'authcode'",
},
cli.StringFlag{
Name: "cacerts",
Usage: "Location of CaCerts to use",
Expand Down Expand Up @@ -180,6 +190,7 @@ func runCredential(ctx *cli.Context) error {
authProvider: ctx.String("auth-provider"),
caCerts: ctx.String("cacerts"),
skipVerify: ctx.Bool("skip-verify"),
authFlow: ctx.String("auth-flow"),
}

tlsConfig, err := getTLSConfig(input.skipVerify, input.caCerts)
Expand Down Expand Up @@ -492,6 +503,7 @@ func samlAuth(client *http.Client, input *LoginInput, useV1Public bool) (managem

interrupt := make(chan os.Signal, 1)
signal.Notify(interrupt, os.Interrupt)
defer signal.Stop(interrupt)

// Timeout for the login flow.
timeout := time.NewTimer(15 * time.Minute)
Expand Down
250 changes: 250 additions & 0 deletions cmd/kubectl_token_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,265 @@ package cmd
import (
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"os"
"os/exec"
"os/signal"
"runtime"
"strconv"
"strings"
"sync"
"time"

apiv3 "github.com/rancher/rancher/pkg/apis/management.cattle.io/v3"
managementClient "github.com/rancher/rancher/pkg/client/generated/management/v3"
"github.com/sirupsen/logrus"
"golang.org/x/oauth2"
)

const (
oauthCodeFlowTimeout = 5 * time.Minute
oauthCodeExchangeTimeout = 30 * time.Second
callbackServerShutdownTimeout = time.Second
)

// oauthAuth dispatches the OAuth authentication flow based on the auth flow type.
func oauthAuth(client *http.Client, input *LoginInput, provider TypedProvider, useV1Public bool) (*managementClient.Token, error) {
if input.authFlow == "" { // The flag has precedence over the env variable.
input.authFlow = os.Getenv("CATTLE_OAUTH_AUTH_FLOW")
}
input.authFlow = strings.ToLower(input.authFlow)

switch input.authFlow {
case authCodeFlow:
return oauthAuthCodeAuth(client, input, provider, oauthCodeFlowTimeout, useV1Public, openBrowser)
case deviceAuthFlow, "": // Default to device code flow if not specified.
return oauthDeviceCodeAuth(client, input, provider, useV1Public)
default:
return nil, fmt.Errorf("invalid auth-flow value: %s", input.authFlow)
}
}

func getCallbackPort() (int, error) {
env := os.Getenv("CATTLE_OAUTH_CALLBACK_PORT")
if env == "" {
return 0, nil // Use random port
}

port, err := strconv.Atoi(env)
if err != nil {
return 0, fmt.Errorf("invalid callback port value: %w", err)
}
if port < 0 || port > 65535 {
return 0, errors.New("callback port value must be between 0 and 65535")
}
if port > 0 && port < 1024 {
logrus.Warnf("Using privileged port %d may require elevated permissions", port)
}

return port, nil
}

// oauthAuthCodeAuth implements the authorization code flow for OAuth authentication.
func oauthAuthCodeAuth(
client *http.Client,
input *LoginInput,
provider TypedProvider,
timeoutAfter time.Duration,
useV1Public bool,
openBrowser openBrowserFunc,
) (*managementClient.Token, error) {
oauthConfig, err := newOauthConfig(provider)
if err != nil {
return nil, fmt.Errorf("failed to create oauth config: %w", err)
}

callbackPort, err := getCallbackPort()
if err != nil {
return nil, err
}

// Generate PKCE verifier (43-128 chars, cryptographically random).
verifier := oauth2.GenerateVerifier()

// Start a local callback server.
// Note: RFC 8252 Section 7.3 explicitly allows HTTP for localhost
// https://datatracker.ietf.org/doc/html/rfc8252#section-7.3
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", callbackPort))
if err != nil {
return nil, fmt.Errorf("failed to start local callback server on port %d: %w", callbackPort, err)
}
defer listener.Close()

if callbackPort == 0 {
callbackPort = listener.Addr().(*net.TCPAddr).Port
}

oauthConfig.RedirectURL = fmt.Sprintf("http://localhost:%d", callbackPort)

// Generate state for CSRF protection.
state, err := generateState()
if err != nil {
return nil, err
}

// Build the authorization URL with PKCE challenge.
authURL := oauthConfig.AuthCodeURL(state, oauth2.S256ChallengeOption(verifier))

// Start the callback server.
resultCh := make(chan callbackResult, 1)
srv := startCallbackServer(listener, state, resultCh)
defer func() {
shutdownCtx, cancel := context.WithTimeout(context.Background(), callbackServerShutdownTimeout)
defer cancel()
if err := srv.Shutdown(shutdownCtx); err != nil {
srv.Close() // Force close if graceful shutdown fails.
}
}()

// Open the user's browser.
customPrint("\nOpening browser for authentication...\n")
if err := openBrowser(authURL); err != nil {
logrus.Debugf("Failed to open browser: %v", err)
customPrint(fmt.Sprintf("Failed to open browser automatically. Please open the following URL manually:\n%s\n", authURL))
}

// Wait for the callback with the authorization code.
interrupt := make(chan os.Signal, 1)
signal.Notify(interrupt, os.Interrupt)
defer signal.Stop(interrupt)

timeout := time.NewTimer(timeoutAfter)
defer timeout.Stop()

select {
case result := <-resultCh:
if result.err != nil {
return nil, result.err
}

// Exchange the authorization code for tokens using the PKCE verifier.
ctx, cancel := context.WithTimeout(context.Background(), oauthCodeExchangeTimeout)
defer cancel()
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)

oauthToken, err := oauthConfig.Exchange(ctx, result.code, oauth2.VerifierOption(verifier))
if err != nil {
return nil, fmt.Errorf("failed to exchange authorization code for token: %w", err)
}

// Send the id_token to Rancher to get a Rancher token.
return rancherLogin(client, input, oauthToken, useV1Public)
case <-timeout.C:
return nil, errors.New("timed out waiting for browser authentication")
case <-interrupt:
return nil, errors.New("authentication interrupted by user")
}
}

// generateState creates a random string to be used as the OAuth state parameter for CSRF protection.
func generateState() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", fmt.Errorf("failed to generate random state: %w", err)
}
return base64.RawURLEncoding.EncodeToString(b), nil
}

type openBrowserFunc func(url string) error

// openBrowser attempts to open the specified URL in the user's default browser, with support for Windows, macOS, and Linux.
func openBrowser(url string) error {
var cmd string
var args []string

switch runtime.GOOS {
case "darwin":
cmd = "open"
args = []string{url}
case "linux":
cmd = "xdg-open"
args = []string{url}
case "windows":
cmd = "cmd"
args = []string{"/c", "start", "", url}
default:
return fmt.Errorf("unsupported platform: %s", runtime.GOOS)
}

return exec.Command(cmd, args...).Start()
}

// callbackResult is used to communicate the result of the OAuth callback handling back to the main authentication flow.
type callbackResult struct {
code string
state string
err error
}

// startCallbackServer starts an HTTP server to listen for the OAuth callback and validates the state parameter for CSRF protection.
func startCallbackServer(listener net.Listener, expectedState string, resultCh chan<- callbackResult) *http.Server {
mux := http.NewServeMux()
srv := &http.Server{
Handler: mux,
ReadTimeout: 15 * time.Second,
WriteTimeout: 15 * time.Second,
IdleTimeout: 30 * time.Second,
}

var once sync.Once

mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
defer func() {
if r := recover(); r != nil {
resultCh <- callbackResult{err: fmt.Errorf("panic in callback handler: %v", r)}
}
}()

once.Do(func() {
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We expect only a SINGLE request.

query := r.URL.Query()

// Check for error response from the IdP.
if errCode := query.Get("error"); errCode != "" {
errDesc := query.Get("error_description")
fmt.Fprintf(w, "<html><body><h1>Authentication Failed</h1><p>%s: %s</p><p>You can close this window.</p></body></html>", errCode, errDesc)
resultCh <- callbackResult{err: fmt.Errorf("authentication error: %s: %s", errCode, errDesc)}
return
}

// Validate state for CSRF protection.
state := query.Get("state")
if state != expectedState {
http.Error(w, "Invalid state parameter", http.StatusBadRequest)
resultCh <- callbackResult{err: errors.New("invalid state parameter in callback (possible CSRF attack)")}
return
}

code := query.Get("code")
if code == "" {
http.Error(w, "Missing authorization code", http.StatusBadRequest)
resultCh <- callbackResult{err: errors.New("missing authorization code in callback")}
return
}

fmt.Fprint(w, "<html><body><h1>Authentication Successful</h1><p>You can close this window and return to the terminal.</p></body></html>")
resultCh <- callbackResult{code: code, state: state}
})
})

go srv.Serve(listener)

return srv
}

// oauthDeviceCodeAuth implements the device code flow for OAuth authentication.
func oauthDeviceCodeAuth(client *http.Client, input *LoginInput, provider TypedProvider, useV1Public bool) (*managementClient.Token, error) {
oauthConfig, err := newOauthConfig(provider)
if err != nil {
return nil, fmt.Errorf("failed to create oauth config: %w", err)
Expand Down Expand Up @@ -66,6 +314,7 @@ func newOauthConfig(provider TypedProvider) (*oauth2.Config, error) {
}, nil
}

// rancherLogin sends the obtained OAuth token to Rancher to exchange it for a Rancher token that can be used for API authentication.
func rancherLogin(client *http.Client, input *LoginInput, oauthToken *oauth2.Token, useV1Public bool) (*managementClient.Token, error) {
reqURL := fmt.Sprintf(loginURL, input.server)
if !useV1Public {
Expand All @@ -91,6 +340,7 @@ func rancherLogin(client *http.Client, input *LoginInput, oauthToken *oauth2.Tok
if err != nil {
return nil, fmt.Errorf("error creating request: %w", err)
}
req.Header.Set("Content-Type", "application/json")

resp, respBody, err := doRequest(client, req)
if err == nil && resp.StatusCode != http.StatusCreated {
Expand Down
Loading