Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 14 additions & 37 deletions cmd/src/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,9 @@ Examples:

$ src login https://sourcegraph.com

Use OAuth device flow to authenticate:
If no access token is configured, 'src login' uses OAuth device flow automatically:

$ src login --oauth https://sourcegraph.com


Override the default client id used during device flow when authenticating:

$ src login --oauth https://sourcegraph.com
$ src login https://sourcegraph.com
`

flagSet := flag.NewFlagSet("login", flag.ExitOnError)
Expand All @@ -47,7 +42,6 @@ Examples:

var (
apiFlags = api.NewFlags(flagSet)
useOAuth = flagSet.Bool("oauth", false, "Use OAuth device flow to obtain an access token interactively")
)

handler := func(args []string) error {
Expand All @@ -69,7 +63,6 @@ Examples:
client: client,
endpoint: endpoint,
out: os.Stdout,
useOAuth: *useOAuth,
apiFlags: apiFlags,
oauthClient: oauth.NewClient(oauth.DefaultClientID),
})
Expand All @@ -87,7 +80,6 @@ type loginParams struct {
client api.Client
endpoint string
out io.Writer
useOAuth bool
apiFlags *api.Flags
oauthClient oauth.Client
}
Expand All @@ -103,46 +95,31 @@ const (
loginFlowValidate
)

var loadStoredOAuthToken = oauth.LoadToken

func loginCmd(ctx context.Context, p loginParams) error {
if p.cfg.ConfigFilePath != "" {
fmt.Fprintln(p.out)
fmt.Fprintf(p.out, "⚠️ Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT, SRC_ACCESS_TOKEN, and SRC_PROXY instead, and then remove %s. See https://github.com/sourcegraph/src-cli#readme for more information.\n", p.cfg.ConfigFilePath)
}

_, flow := selectLoginFlow(ctx, p)
_, flow := selectLoginFlow(p)
return flow(ctx, p)
}

// selectLoginFlow decides what login flow to run based on flags and config.
func selectLoginFlow(ctx context.Context, p loginParams) (loginFlowKind, loginFlow) {
// selectLoginFlow decides what login flow to run based on configured AuthMode.
func selectLoginFlow(p loginParams) (loginFlowKind, loginFlow) {
endpointArg := cleanEndpoint(p.endpoint)

if p.useOAuth {
switch p.cfg.AuthMode() {
case AuthModeOAuth:
return loginFlowOAuth, runOAuthLogin
}
if !hasEffectiveAuth(ctx, p.cfg, endpointArg) {
case AuthModeAccessToken:
if endpointArg != p.cfg.Endpoint {
return loginFlowEndpointConflict, runEndpointConflictLogin
}
return loginFlowValidate, runValidatedLogin
default:
return loginFlowMissingAuth, runMissingAuthLogin
}
if endpointArg != p.cfg.Endpoint {
return loginFlowEndpointConflict, runEndpointConflictLogin
}
return loginFlowValidate, runValidatedLogin
}

// hasEffectiveAuth determines whether we have auth credentials to continue. It first checks for a resolved Access Token in
// config, then it checks for a stored OAuth token.
func hasEffectiveAuth(ctx context.Context, cfg *config, resolvedEndpoint string) bool {
if cfg.AccessToken != "" {
return true
}

if _, err := loadStoredOAuthToken(ctx, resolvedEndpoint); err == nil {
return true
}

return false
}

func printLoginProblem(out io.Writer, problem string) {
Expand All @@ -157,6 +134,6 @@ func loginAccessTokenMessage(endpoint string) string {

To verify that it's working, run the login command again.

Alternatively, you can try logging in using OAuth by running: src login --oauth %s
Alternatively, you can try logging in interactively by running: src login %s
`, endpoint, endpoint, endpoint)
}
18 changes: 16 additions & 2 deletions cmd/src/login_oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
"github.com/sourcegraph/src-cli/internal/oauth"
)

var loadStoredOAuthToken = oauth.LoadToken

func runOAuthLogin(ctx context.Context, p loginParams) error {
endpointArg := cleanEndpoint(p.endpoint)
client, err := oauthLoginClient(ctx, p, endpointArg)
Expand All @@ -32,7 +34,15 @@ func runOAuthLogin(ctx context.Context, p loginParams) error {
return nil
}

// oauthLoginClient returns a api.Client with the OAuth token set. It will check secret storage for a token
// and use it if one is present.
// If no token is found, it will start a OAuth Device flow to get a token and storage in secret storage.
func oauthLoginClient(ctx context.Context, p loginParams, endpoint string) (api.Client, error) {
// if we have a stored token, used it. Otherwise run the device flow
if token, err := loadStoredOAuthToken(ctx, endpoint); err == nil {
return newOAuthAPIClient(p, endpoint, token), nil
}

token, err := runOAuthDeviceFlow(ctx, endpoint, p.out, p.oauthClient)
if err != nil {
return nil, err
Expand All @@ -43,15 +53,19 @@ func oauthLoginClient(ctx context.Context, p loginParams, endpoint string) (api.
fmt.Fprintf(p.out, "⚠️ Warning: Failed to store token in keyring store: %q. Continuing with this session only.\n", err)
}

return newOAuthAPIClient(p, endpoint, token), nil
}

func newOAuthAPIClient(p loginParams, endpoint string, token *oauth.Token) api.Client {
return api.NewClient(api.ClientOpts{
Endpoint: p.cfg.Endpoint,
Endpoint: endpoint,
AdditionalHeaders: p.cfg.AdditionalHeaders,
Flags: p.apiFlags,
Out: p.out,
ProxyURL: p.cfg.ProxyURL,
ProxyPath: p.cfg.ProxyPath,
OAuthToken: token,
}), nil
})
}

func runOAuthDeviceFlow(ctx context.Context, endpoint string, out io.Writer, client oauth.Client) (*oauth.Token, error) {
Expand Down
139 changes: 87 additions & 52 deletions cmd/src/login_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"

"github.com/sourcegraph/src-cli/internal/cmderrors"
"github.com/sourcegraph/src-cli/internal/oauth"
Expand All @@ -18,51 +19,47 @@ func TestLogin(t *testing.T) {
check := func(t *testing.T, cfg *config, endpointArg string) (output string, err error) {
t.Helper()

restoreStoredOAuthLoader(t, func(context.Context, string) (*oauth.Token, error) {
return nil, fmt.Errorf("not found")
})

var out bytes.Buffer
err = loginCmd(context.Background(), loginParams{
cfg: cfg,
client: cfg.apiClient(nil, io.Discard),
endpoint: endpointArg,
out: &out,
oauthClient: oauth.NewClient(oauth.DefaultClientID),
oauthClient: fakeOAuthClient{startErr: fmt.Errorf("oauth unavailable")},
})
return strings.TrimSpace(out.String()), err
}

t.Run("different endpoint in config vs. arg", func(t *testing.T) {
out, err := check(t, &config{Endpoint: "https://example.com"}, "https://sourcegraph.example.com")
if err != cmderrors.ExitCode1 {
if err == nil {
t.Fatal(err)
}
wantOut := "❌ Problem: No access token is configured.\n\n🛠 To fix: Create an access token by going to https://sourcegraph.example.com/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=https://sourcegraph.example.com\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again.\n\n Alternatively, you can try logging in using OAuth by running: src login --oauth https://sourcegraph.example.com"
if out != wantOut {
t.Errorf("got output %q, want %q", out, wantOut)
if !strings.Contains(out, "OAuth Device flow authentication failed:") {
t.Errorf("got output %q, want oauth failure output", out)
}
})

t.Run("no access token", func(t *testing.T) {
t.Run("no access token triggers oauth flow", func(t *testing.T) {
out, err := check(t, &config{Endpoint: "https://example.com"}, "https://sourcegraph.example.com")
if err != cmderrors.ExitCode1 {
if err == nil {
t.Fatal(err)
}
wantOut := "❌ Problem: No access token is configured.\n\n🛠 To fix: Create an access token by going to https://sourcegraph.example.com/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=https://sourcegraph.example.com\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again.\n\n Alternatively, you can try logging in using OAuth by running: src login --oauth https://sourcegraph.example.com"
if out != wantOut {
t.Errorf("got output %q, want %q", out, wantOut)
if !strings.Contains(out, "OAuth Device flow authentication failed:") {
t.Errorf("got output %q, want oauth failure output", out)
}
})

t.Run("warning when using config file", func(t *testing.T) {
out, err := check(t, &config{Endpoint: "https://example.com", ConfigFilePath: "f"}, "https://example.com")
if err != cmderrors.ExitCode1 {
if err == nil {
t.Fatal(err)
}
wantOut := "⚠️ Warning: Configuring src with a JSON file is deprecated. Please migrate to using the env vars SRC_ENDPOINT, SRC_ACCESS_TOKEN, and SRC_PROXY instead, and then remove f. See https://github.com/sourcegraph/src-cli#readme for more information.\n\n❌ Problem: No access token is configured.\n\n🛠 To fix: Create an access token by going to https://example.com/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=https://example.com\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again.\n\n Alternatively, you can try logging in using OAuth by running: src login --oauth https://example.com"
if out != wantOut {
t.Errorf("got output %q, want %q", out, wantOut)
if !strings.Contains(out, "Configuring src with a JSON file is deprecated") {
t.Errorf("got output %q, want deprecation warning", out)
}
if !strings.Contains(out, "OAuth Device flow authentication failed:") {
t.Errorf("got output %q, want oauth failure output", out)
}
})

Expand All @@ -77,7 +74,7 @@ func TestLogin(t *testing.T) {
if err != cmderrors.ExitCode1 {
t.Fatal(err)
}
wantOut := "❌ Problem: Invalid access token.\n\n🛠 To fix: Create an access token by going to $ENDPOINT/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=$ENDPOINT\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again.\n\n Alternatively, you can try logging in using OAuth by running: src login --oauth $ENDPOINT\n\n (If you need to supply custom HTTP request headers, see information about SRC_HEADER_* and SRC_HEADERS env vars at https://github.com/sourcegraph/src-cli/blob/main/AUTH_PROXY.md)"
wantOut := "❌ Problem: Invalid access token.\n\n🛠 To fix: Create an access token by going to $ENDPOINT/user/settings/tokens, then set the following environment variables in your terminal:\n\n export SRC_ENDPOINT=$ENDPOINT\n export SRC_ACCESS_TOKEN=(your access token)\n\n To verify that it's working, run the login command again.\n\n Alternatively, you can try logging in interactively by running: src login $ENDPOINT\n\n (If you need to supply custom HTTP request headers, see information about SRC_HEADER_* and SRC_HEADERS env vars at https://github.com/sourcegraph/src-cli/blob/main/AUTH_PROXY.md)"
wantOut = strings.ReplaceAll(wantOut, "$ENDPOINT", endpoint)
if out != wantOut {
t.Errorf("got output %q, want %q", out, wantOut)
Expand All @@ -101,33 +98,86 @@ func TestLogin(t *testing.T) {
t.Errorf("got output %q, want %q", out, wantOut)
}
})
}

func TestSelectLoginFlow(t *testing.T) {
restoreStoredOAuthLoader(t, func(context.Context, string) (*oauth.Token, error) {
return nil, fmt.Errorf("not found")
})
t.Run("reuses stored oauth token before device flow", func(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, `{"data":{"currentUser":{"username":"alice"}}}`)
}))
defer s.Close()

t.Run("uses oauth flow when oauth flag is set", func(t *testing.T) {
params := loginParams{
cfg: &config{Endpoint: "https://example.com"},
endpoint: "https://example.com",
useOAuth: true,
}
restoreStoredOAuthLoader(t, func(context.Context, string) (*oauth.Token, error) {
return &oauth.Token{
Endpoint: s.URL,
ClientID: oauth.DefaultClientID,
AccessToken: "oauth-token",
ExpiresAt: time.Now().Add(time.Hour),
}, nil
})

if got, _ := selectLoginFlow(context.Background(), params); got != loginFlowOAuth {
t.Fatalf("flow = %v, want %v", got, loginFlowOAuth)
startCalled := false
var out bytes.Buffer
err := loginCmd(context.Background(), loginParams{
cfg: &config{Endpoint: s.URL},
client: (&config{Endpoint: s.URL}).apiClient(nil, io.Discard),
endpoint: s.URL,
out: &out,
oauthClient: fakeOAuthClient{
startErr: fmt.Errorf("unexpected call to Start"),
startCalled: &startCalled,
},
})
if err != nil {
t.Fatal(err)
}
if startCalled {
t.Fatal("expected stored oauth token to avoid device flow")
}
gotOut := strings.TrimSpace(out.String())
wantOut := "✔︎ Authenticated as alice on $ENDPOINT\n\n\n✔︎ Authenticated with OAuth credentials"
wantOut = strings.ReplaceAll(wantOut, "$ENDPOINT", s.URL)
if gotOut != wantOut {
t.Errorf("got output %q, want %q", gotOut, wantOut)
}
})
}

type fakeOAuthClient struct {
startErr error
startCalled *bool
}

func (f fakeOAuthClient) ClientID() string {
return oauth.DefaultClientID
}

func (f fakeOAuthClient) Discover(context.Context, string) (*oauth.OIDCConfiguration, error) {
return nil, fmt.Errorf("unexpected call to Discover")
}

func (f fakeOAuthClient) Start(context.Context, string, []string) (*oauth.DeviceAuthResponse, error) {
if f.startCalled != nil {
*f.startCalled = true
}
return nil, f.startErr
}

func (f fakeOAuthClient) Poll(context.Context, string, string, time.Duration, int) (*oauth.TokenResponse, error) {
return nil, fmt.Errorf("unexpected call to Poll")
}

func (f fakeOAuthClient) Refresh(context.Context, *oauth.Token) (*oauth.TokenResponse, error) {
return nil, fmt.Errorf("unexpected call to Refresh")
}

t.Run("uses missing auth flow when auth is unavailable", func(t *testing.T) {
func TestSelectLoginFlow(t *testing.T) {
t.Run("uses oauth flow when no access token is configured", func(t *testing.T) {
params := loginParams{
cfg: &config{Endpoint: "https://example.com"},
endpoint: "https://sourcegraph.example.com",
}

if got, _ := selectLoginFlow(context.Background(), params); got != loginFlowMissingAuth {
t.Fatalf("flow = %v, want %v", got, loginFlowMissingAuth)
if got, _ := selectLoginFlow(params); got != loginFlowOAuth {
t.Fatalf("flow = %v, want %v", got, loginFlowOAuth)
}
})

Expand All @@ -137,7 +187,7 @@ func TestSelectLoginFlow(t *testing.T) {
endpoint: "https://sourcegraph.example.com",
}

if got, _ := selectLoginFlow(context.Background(), params); got != loginFlowEndpointConflict {
if got, _ := selectLoginFlow(params); got != loginFlowEndpointConflict {
t.Fatalf("flow = %v, want %v", got, loginFlowEndpointConflict)
}
})
Expand All @@ -148,22 +198,7 @@ func TestSelectLoginFlow(t *testing.T) {
endpoint: "https://example.com",
}

if got, _ := selectLoginFlow(context.Background(), params); got != loginFlowValidate {
t.Fatalf("flow = %v, want %v", got, loginFlowValidate)
}
})

t.Run("treats stored oauth as effective auth", func(t *testing.T) {
restoreStoredOAuthLoader(t, func(context.Context, string) (*oauth.Token, error) {
return &oauth.Token{AccessToken: "oauth-token"}, nil
})

params := loginParams{
cfg: &config{Endpoint: "https://example.com"},
endpoint: "https://example.com",
}

if got, _ := selectLoginFlow(context.Background(), params); got != loginFlowValidate {
if got, _ := selectLoginFlow(params); got != loginFlowValidate {
t.Fatalf("flow = %v, want %v", got, loginFlowValidate)
}
})
Expand Down
Loading
Loading