Skip to content

Commit 3608a6e

Browse files
committed
fix tests and remove unused login funcs
1 parent 2e86aaf commit 3608a6e

3 files changed

Lines changed: 162 additions & 94 deletions

File tree

cmd/src/login.go

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -107,37 +107,25 @@ type loginParams struct {
107107

108108
type loginFlow func(context.Context, loginParams) error
109109

110-
type loginFlowKind int
111-
112-
const (
113-
loginFlowOAuth loginFlowKind = iota
114-
loginFlowMissingAuth
115-
loginFlowEndpointConflict
116-
loginFlowValidate
117-
)
118-
119110
func loginCmd(ctx context.Context, p loginParams) error {
120111
if err := p.cfg.requireCIAccessToken(); err != nil {
121112
return err
122113
}
123114

124-
_, flow := selectLoginFlow(p)
115+
flow := selectLoginFlow(p)
125116
if err := flow(ctx, p); err != nil {
126117
return err
127118
}
128119
return nil
129120
}
130121

131122
// selectLoginFlow decides what login flow to run based on configured AuthMode.
132-
func selectLoginFlow(p loginParams) (loginFlowKind, loginFlow) {
133-
switch p.cfg.AuthMode() {
134-
case AuthModeOAuth:
135-
return loginFlowOAuth, runOAuthLogin
136-
case AuthModeAccessToken:
137-
return loginFlowValidate, runValidatedLogin
138-
default:
139-
return loginFlowMissingAuth, runMissingAuthLogin
123+
124+
func selectLoginFlow(p loginParams) loginFlow {
125+
if p.cfg.AuthMode() == AuthModeAccessToken {
126+
return runValidatedLogin
140127
}
128+
return runOAuthLogin
141129
}
142130

143131
func printLoginProblem(out io.Writer, problem string) {

cmd/src/login_test.go

Lines changed: 156 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"net/http"
99
"net/http/httptest"
1010
"net/url"
11+
"os"
1112
"strings"
1213
"testing"
1314
"time"
@@ -25,34 +26,86 @@ func mustParseURL(t *testing.T, raw string) *url.URL {
2526
return u
2627
}
2728

29+
func loginCommand(t *testing.T) *command {
30+
t.Helper()
31+
for _, cmd := range commands {
32+
if cmd.matches("login") {
33+
return cmd
34+
}
35+
}
36+
t.Fatal("login command not found")
37+
return nil
38+
}
39+
40+
func captureProcessOutput(t *testing.T, fn func() error) (stdout string, stderr string, err error) {
41+
t.Helper()
42+
43+
stdoutR, stdoutW, err := os.Pipe()
44+
if err != nil {
45+
t.Fatal(err)
46+
}
47+
stderrR, stderrW, err := os.Pipe()
48+
if err != nil {
49+
_ = stdoutR.Close()
50+
_ = stdoutW.Close()
51+
t.Fatal(err)
52+
}
53+
54+
oldStdout := os.Stdout
55+
oldStderr := os.Stderr
56+
os.Stdout = stdoutW
57+
os.Stderr = stderrW
58+
defer func() {
59+
os.Stdout = oldStdout
60+
os.Stderr = oldStderr
61+
}()
62+
63+
err = fn()
64+
65+
_ = stdoutW.Close()
66+
_ = stderrW.Close()
67+
68+
stdoutBytes, readErr := io.ReadAll(stdoutR)
69+
if readErr != nil {
70+
t.Fatal(readErr)
71+
}
72+
stderrBytes, readErr := io.ReadAll(stderrR)
73+
if readErr != nil {
74+
t.Fatal(readErr)
75+
}
76+
77+
return strings.TrimSpace(string(stdoutBytes)), strings.TrimSpace(string(stderrBytes)), err
78+
}
79+
80+
func runLoginHandler(t *testing.T, cfgValue *config, args ...string) (stdout string, stderr string, err error) {
81+
t.Helper()
82+
83+
oldCfg := cfg
84+
cfg = cfgValue
85+
t.Cleanup(func() { cfg = oldCfg })
86+
87+
return captureProcessOutput(t, func() error {
88+
return loginCommand(t).handler(args)
89+
})
90+
}
91+
2892
func TestLogin(t *testing.T) {
29-
check := func(t *testing.T, cfg *config, endpointArgURL *url.URL) (output string, err error) {
93+
check := func(t *testing.T, cfg *config) (output string, err error) {
3094
t.Helper()
3195

3296
var out bytes.Buffer
3397
err = loginCmd(context.Background(), loginParams{
34-
cfg: cfg,
35-
client: cfg.apiClient(nil, io.Discard),
36-
out: &out,
37-
oauthClient: fakeOAuthClient{startErr: fmt.Errorf("oauth unavailable")},
38-
loginEndpointURL: endpointArgURL,
98+
cfg: cfg,
99+
client: cfg.apiClient(nil, io.Discard),
100+
out: &out,
101+
oauthClient: fakeOAuthClient{startErr: fmt.Errorf("oauth unavailable")},
39102
})
40103
return strings.TrimSpace(out.String()), err
41104
}
42105

43-
t.Run("different endpoint in config vs. arg", func(t *testing.T) {
44-
out, err := check(t, &config{endpointURL: &url.URL{Scheme: "https", Host: "example.com"}}, &url.URL{Scheme: "https", Host: "sourcegraph.example.com"})
45-
if err == nil {
46-
t.Fatal(err)
47-
}
48-
if !strings.Contains(out, "The configured endpoint is https://example.com, not https://sourcegraph.example.com.") {
49-
t.Errorf("got output %q, want configured endpoint error", out)
50-
}
51-
})
52-
53106
t.Run("no access token triggers oauth flow", func(t *testing.T) {
54107
u := &url.URL{Scheme: "https", Host: "example.com"}
55-
out, err := check(t, &config{endpointURL: u}, u)
108+
out, err := check(t, &config{endpointURL: u})
56109
if err == nil {
57110
t.Fatal(err)
58111
}
@@ -63,7 +116,7 @@ func TestLogin(t *testing.T) {
63116

64117
t.Run("CI requires access token", func(t *testing.T) {
65118
u := &url.URL{Scheme: "https", Host: "example.com"}
66-
out, err := check(t, &config{endpointURL: u, inCI: true}, u)
119+
out, err := check(t, &config{endpointURL: u, inCI: true})
67120
if err != errCIAccessTokenRequired {
68121
t.Fatalf("err = %v, want %v", err, errCIAccessTokenRequired)
69122
}
@@ -72,28 +125,14 @@ func TestLogin(t *testing.T) {
72125
}
73126
})
74127

75-
t.Run("warning when using config file", func(t *testing.T) {
76-
endpoint := &url.URL{Scheme: "https", Host: "example.com"}
77-
out, err := check(t, &config{endpointURL: endpoint, configFilePath: "f"}, endpoint)
78-
if err != cmderrors.ExitCode1 {
79-
t.Fatal(err)
80-
}
81-
if !strings.Contains(out, "Configuring src with a JSON file is deprecated") {
82-
t.Errorf("got output %q, want deprecation warning", out)
83-
}
84-
if !strings.Contains(out, "OAuth Device flow authentication failed:") {
85-
t.Errorf("got output %q, want oauth failure output", out)
86-
}
87-
})
88-
89128
t.Run("invalid access token", func(t *testing.T) {
90129
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
91130
http.Error(w, "", http.StatusUnauthorized)
92131
}))
93132
defer s.Close()
94133

95134
u := mustParseURL(t, s.URL)
96-
out, err := check(t, &config{endpointURL: u, accessToken: "x"}, u)
135+
out, err := check(t, &config{endpointURL: u, accessToken: "x"})
97136
if err != cmderrors.ExitCode1 {
98137
t.Fatal(err)
99138
}
@@ -111,11 +150,11 @@ func TestLogin(t *testing.T) {
111150
defer s.Close()
112151

113152
u := mustParseURL(t, s.URL)
114-
out, err := check(t, &config{endpointURL: u, accessToken: "x"}, u)
153+
out, err := check(t, &config{endpointURL: u, accessToken: "x"})
115154
if err != nil {
116155
t.Fatal(err)
117156
}
118-
wantOut := "✔︎ Authenticated as alice on $ENDPOINT\n\n\n💡 Tip: To use this endpoint in your shell, run:\n\n export SRC_ENDPOINT=$ENDPOINT"
157+
wantOut := "✔︎ Authenticated as alice on $ENDPOINT"
119158
wantOut = strings.ReplaceAll(wantOut, "$ENDPOINT", s.URL)
120159
if out != wantOut {
121160
t.Errorf("got output %q, want %q", out, wantOut)
@@ -156,14 +195,95 @@ func TestLogin(t *testing.T) {
156195
t.Fatal("expected stored oauth token to avoid device flow")
157196
}
158197
gotOut := strings.TrimSpace(out.String())
159-
wantOut := "✔︎ Authenticated as alice on $ENDPOINT\n\n\n✔︎ Authenticated with OAuth credentials\n\n💡 Tip: To use this endpoint in your shell, run:\n\n export SRC_ENDPOINT=$ENDPOINT"
198+
wantOut := "✔︎ Authenticated as alice on $ENDPOINT\n\n\n✔︎ Authenticated with OAuth credentials"
160199
wantOut = strings.ReplaceAll(wantOut, "$ENDPOINT", s.URL)
161200
if gotOut != wantOut {
162201
t.Errorf("got output %q, want %q", gotOut, wantOut)
163202
}
164203
})
165204
}
166205

206+
func TestLoginHandler(t *testing.T) {
207+
t.Run("warns when login endpoint differs from configured endpoint", func(t *testing.T) {
208+
t.Setenv("SRC_ENDPOINT", "https://example.com")
209+
210+
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
211+
fmt.Fprintln(w, `{"data":{"currentUser":{"username":"alice"}}}`)
212+
}))
213+
defer s.Close()
214+
215+
stdout, stderr, err := runLoginHandler(t, &config{
216+
endpointURL: mustParseURL(t, "https://example.com"),
217+
accessToken: "x",
218+
}, s.URL)
219+
if err != nil {
220+
t.Fatal(err)
221+
}
222+
if !strings.Contains(stderr, "Warning: Logging into "+s.URL+" instead of the configured endpoint https://example.com.") {
223+
t.Fatalf("stderr = %q, want endpoint warning", stderr)
224+
}
225+
if !strings.Contains(stderr, "export SRC_ENDPOINT="+s.URL) {
226+
t.Fatalf("stderr = %q, want shell tip", stderr)
227+
}
228+
if !strings.Contains(stdout, "✔︎ Authenticated as alice on "+s.URL) {
229+
t.Fatalf("stdout = %q, want validation output", stdout)
230+
}
231+
})
232+
233+
t.Run("warns when no SRC_ENDPOINT is configured in the environment", func(t *testing.T) {
234+
if oldValue, ok := os.LookupEnv("SRC_ENDPOINT"); ok {
235+
_ = os.Unsetenv("SRC_ENDPOINT")
236+
t.Cleanup(func() { _ = os.Setenv("SRC_ENDPOINT", oldValue) })
237+
} else {
238+
_ = os.Unsetenv("SRC_ENDPOINT")
239+
}
240+
241+
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
242+
fmt.Fprintln(w, `{"data":{"currentUser":{"username":"alice"}}}`)
243+
}))
244+
defer s.Close()
245+
246+
stdout, stderr, err := runLoginHandler(t, &config{
247+
endpointURL: mustParseURL(t, SGDotComEndpoint),
248+
accessToken: "x",
249+
}, s.URL)
250+
if err != nil {
251+
t.Fatal(err)
252+
}
253+
if !strings.Contains(stderr, "Warning: No SRC_ENDPOINT is configured in the environment. Logging in using \""+s.URL+"\".") {
254+
t.Fatalf("stderr = %q, want default-endpoint warning", stderr)
255+
}
256+
if !strings.Contains(stderr, "NOTE: By default src will use \""+SGDotComEndpoint+"\" if SRC_ENDPOINT is not set.") {
257+
t.Fatalf("stderr = %q, want default endpoint note", stderr)
258+
}
259+
if !strings.Contains(stdout, "✔︎ Authenticated as alice on "+s.URL) {
260+
t.Fatalf("stdout = %q, want validation output", stdout)
261+
}
262+
})
263+
264+
t.Run("warns when using config file", func(t *testing.T) {
265+
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
266+
fmt.Fprintln(w, `{"data":{"currentUser":{"username":"alice"}}}`)
267+
}))
268+
defer s.Close()
269+
270+
stdout, stderr, err := runLoginHandler(t, &config{
271+
endpointURL: mustParseURL(t, s.URL),
272+
accessToken: "x",
273+
configFilePath: "f",
274+
})
275+
if err != nil {
276+
t.Fatal(err)
277+
}
278+
if !strings.Contains(stderr, "Configuring src with a JSON file is deprecated") {
279+
t.Fatalf("stderr = %q, want deprecation warning", stderr)
280+
}
281+
if !strings.Contains(stdout, "✔︎ Authenticated as alice on "+s.URL) {
282+
t.Fatalf("stdout = %q, want validation output", stdout)
283+
}
284+
})
285+
}
286+
167287
type fakeOAuthClient struct {
168288
startErr error
169289
startCalled *bool
@@ -192,39 +312,6 @@ func (f fakeOAuthClient) Refresh(context.Context, *oauth.Token) (*oauth.TokenRes
192312
return nil, fmt.Errorf("unexpected call to Refresh")
193313
}
194314

195-
func TestSelectLoginFlow(t *testing.T) {
196-
t.Run("uses oauth flow when no access token is configured", func(t *testing.T) {
197-
params := loginParams{
198-
cfg: &config{endpointURL: mustParseURL(t, "https://example.com")},
199-
}
200-
201-
if got, _ := selectLoginFlow(params); got != loginFlowOAuth {
202-
t.Fatalf("flow = %v, want %v", got, loginFlowOAuth)
203-
}
204-
})
205-
206-
t.Run("uses endpoint conflict flow when auth exists for a different endpoint", func(t *testing.T) {
207-
params := loginParams{
208-
cfg: &config{endpointURL: mustParseURL(t, "https://example.com"), accessToken: "x"},
209-
loginEndpointURL: mustParseURL(t, "https://sourcegraph.example.com"),
210-
}
211-
212-
if got, _ := selectLoginFlow(params); got != loginFlowEndpointConflict {
213-
t.Fatalf("flow = %v, want %v", got, loginFlowEndpointConflict)
214-
}
215-
})
216-
217-
t.Run("uses validation flow when auth exists for the selected endpoint", func(t *testing.T) {
218-
params := loginParams{
219-
cfg: &config{endpointURL: mustParseURL(t, "https://example.com"), accessToken: "x"},
220-
}
221-
222-
if got, _ := selectLoginFlow(params); got != loginFlowValidate {
223-
t.Fatalf("flow = %v, want %v", got, loginFlowValidate)
224-
}
225-
})
226-
}
227-
228315
func TestValidateBrowserURL(t *testing.T) {
229316
tests := []struct {
230317
name string

cmd/src/login_validate.go

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,6 @@ import (
1111
"github.com/sourcegraph/src-cli/internal/cmderrors"
1212
)
1313

14-
func runMissingAuthLogin(_ context.Context, p loginParams) error {
15-
fmt.Fprintln(p.out)
16-
printLoginProblem(p.out, "No access token is configured.")
17-
fmt.Fprintln(p.out, loginAccessTokenMessage(p.cfg.endpointURL))
18-
return cmderrors.ExitCode1
19-
}
20-
2114
func runValidatedLogin(ctx context.Context, p loginParams) error {
2215
return validateCurrentUser(ctx, p.client, p.out, p.cfg.endpointURL)
2316
}

0 commit comments

Comments
 (0)