88 "net/http"
99 "net/http/httptest"
1010 "net/url"
11+ "os"
1112 "strings"
1213 "testing"
1314 "time"
@@ -25,34 +26,86 @@ func mustParseURL(t *testing.T, raw string) *url.URL {
2526 return u
2627}
2728
29+ func loginCommand (t * testing.T ) * command {
30+ t .Helper ()
31+ for _ , cmd := range commands {
32+ if cmd .matches ("login" ) {
33+ return cmd
34+ }
35+ }
36+ t .Fatal ("login command not found" )
37+ return nil
38+ }
39+
40+ func captureProcessOutput (t * testing.T , fn func () error ) (stdout string , stderr string , err error ) {
41+ t .Helper ()
42+
43+ stdoutR , stdoutW , err := os .Pipe ()
44+ if err != nil {
45+ t .Fatal (err )
46+ }
47+ stderrR , stderrW , err := os .Pipe ()
48+ if err != nil {
49+ _ = stdoutR .Close ()
50+ _ = stdoutW .Close ()
51+ t .Fatal (err )
52+ }
53+
54+ oldStdout := os .Stdout
55+ oldStderr := os .Stderr
56+ os .Stdout = stdoutW
57+ os .Stderr = stderrW
58+ defer func () {
59+ os .Stdout = oldStdout
60+ os .Stderr = oldStderr
61+ }()
62+
63+ err = fn ()
64+
65+ _ = stdoutW .Close ()
66+ _ = stderrW .Close ()
67+
68+ stdoutBytes , readErr := io .ReadAll (stdoutR )
69+ if readErr != nil {
70+ t .Fatal (readErr )
71+ }
72+ stderrBytes , readErr := io .ReadAll (stderrR )
73+ if readErr != nil {
74+ t .Fatal (readErr )
75+ }
76+
77+ return strings .TrimSpace (string (stdoutBytes )), strings .TrimSpace (string (stderrBytes )), err
78+ }
79+
80+ func runLoginHandler (t * testing.T , cfgValue * config , args ... string ) (stdout string , stderr string , err error ) {
81+ t .Helper ()
82+
83+ oldCfg := cfg
84+ cfg = cfgValue
85+ t .Cleanup (func () { cfg = oldCfg })
86+
87+ return captureProcessOutput (t , func () error {
88+ return loginCommand (t ).handler (args )
89+ })
90+ }
91+
2892func TestLogin (t * testing.T ) {
29- check := func (t * testing.T , cfg * config , endpointArgURL * url. URL ) (output string , err error ) {
93+ check := func (t * testing.T , cfg * config ) (output string , err error ) {
3094 t .Helper ()
3195
3296 var out bytes.Buffer
3397 err = loginCmd (context .Background (), loginParams {
34- cfg : cfg ,
35- client : cfg .apiClient (nil , io .Discard ),
36- out : & out ,
37- oauthClient : fakeOAuthClient {startErr : fmt .Errorf ("oauth unavailable" )},
38- loginEndpointURL : endpointArgURL ,
98+ cfg : cfg ,
99+ client : cfg .apiClient (nil , io .Discard ),
100+ out : & out ,
101+ oauthClient : fakeOAuthClient {startErr : fmt .Errorf ("oauth unavailable" )},
39102 })
40103 return strings .TrimSpace (out .String ()), err
41104 }
42105
43- t .Run ("different endpoint in config vs. arg" , func (t * testing.T ) {
44- out , err := check (t , & config {endpointURL : & url.URL {Scheme : "https" , Host : "example.com" }}, & url.URL {Scheme : "https" , Host : "sourcegraph.example.com" })
45- if err == nil {
46- t .Fatal (err )
47- }
48- if ! strings .Contains (out , "The configured endpoint is https://example.com, not https://sourcegraph.example.com." ) {
49- t .Errorf ("got output %q, want configured endpoint error" , out )
50- }
51- })
52-
53106 t .Run ("no access token triggers oauth flow" , func (t * testing.T ) {
54107 u := & url.URL {Scheme : "https" , Host : "example.com" }
55- out , err := check (t , & config {endpointURL : u }, u )
108+ out , err := check (t , & config {endpointURL : u })
56109 if err == nil {
57110 t .Fatal (err )
58111 }
@@ -63,7 +116,7 @@ func TestLogin(t *testing.T) {
63116
64117 t .Run ("CI requires access token" , func (t * testing.T ) {
65118 u := & url.URL {Scheme : "https" , Host : "example.com" }
66- out , err := check (t , & config {endpointURL : u , inCI : true }, u )
119+ out , err := check (t , & config {endpointURL : u , inCI : true })
67120 if err != errCIAccessTokenRequired {
68121 t .Fatalf ("err = %v, want %v" , err , errCIAccessTokenRequired )
69122 }
@@ -72,28 +125,14 @@ func TestLogin(t *testing.T) {
72125 }
73126 })
74127
75- t .Run ("warning when using config file" , func (t * testing.T ) {
76- endpoint := & url.URL {Scheme : "https" , Host : "example.com" }
77- out , err := check (t , & config {endpointURL : endpoint , configFilePath : "f" }, endpoint )
78- if err != cmderrors .ExitCode1 {
79- t .Fatal (err )
80- }
81- if ! strings .Contains (out , "Configuring src with a JSON file is deprecated" ) {
82- t .Errorf ("got output %q, want deprecation warning" , out )
83- }
84- if ! strings .Contains (out , "OAuth Device flow authentication failed:" ) {
85- t .Errorf ("got output %q, want oauth failure output" , out )
86- }
87- })
88-
89128 t .Run ("invalid access token" , func (t * testing.T ) {
90129 s := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
91130 http .Error (w , "" , http .StatusUnauthorized )
92131 }))
93132 defer s .Close ()
94133
95134 u := mustParseURL (t , s .URL )
96- out , err := check (t , & config {endpointURL : u , accessToken : "x" }, u )
135+ out , err := check (t , & config {endpointURL : u , accessToken : "x" })
97136 if err != cmderrors .ExitCode1 {
98137 t .Fatal (err )
99138 }
@@ -111,11 +150,11 @@ func TestLogin(t *testing.T) {
111150 defer s .Close ()
112151
113152 u := mustParseURL (t , s .URL )
114- out , err := check (t , & config {endpointURL : u , accessToken : "x" }, u )
153+ out , err := check (t , & config {endpointURL : u , accessToken : "x" })
115154 if err != nil {
116155 t .Fatal (err )
117156 }
118- wantOut := "✔︎ Authenticated as alice on $ENDPOINT\n \n \n 💡 Tip: To use this endpoint in your shell, run: \n \n export SRC_ENDPOINT=$ENDPOINT "
157+ wantOut := "✔︎ Authenticated as alice on $ENDPOINT"
119158 wantOut = strings .ReplaceAll (wantOut , "$ENDPOINT" , s .URL )
120159 if out != wantOut {
121160 t .Errorf ("got output %q, want %q" , out , wantOut )
@@ -156,14 +195,95 @@ func TestLogin(t *testing.T) {
156195 t .Fatal ("expected stored oauth token to avoid device flow" )
157196 }
158197 gotOut := strings .TrimSpace (out .String ())
159- wantOut := "✔︎ Authenticated as alice on $ENDPOINT\n \n \n ✔︎ Authenticated with OAuth credentials\n \n 💡 Tip: To use this endpoint in your shell, run: \n \n export SRC_ENDPOINT=$ENDPOINT "
198+ wantOut := "✔︎ Authenticated as alice on $ENDPOINT\n \n \n ✔︎ Authenticated with OAuth credentials"
160199 wantOut = strings .ReplaceAll (wantOut , "$ENDPOINT" , s .URL )
161200 if gotOut != wantOut {
162201 t .Errorf ("got output %q, want %q" , gotOut , wantOut )
163202 }
164203 })
165204}
166205
206+ func TestLoginHandler (t * testing.T ) {
207+ t .Run ("warns when login endpoint differs from configured endpoint" , func (t * testing.T ) {
208+ t .Setenv ("SRC_ENDPOINT" , "https://example.com" )
209+
210+ s := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
211+ fmt .Fprintln (w , `{"data":{"currentUser":{"username":"alice"}}}` )
212+ }))
213+ defer s .Close ()
214+
215+ stdout , stderr , err := runLoginHandler (t , & config {
216+ endpointURL : mustParseURL (t , "https://example.com" ),
217+ accessToken : "x" ,
218+ }, s .URL )
219+ if err != nil {
220+ t .Fatal (err )
221+ }
222+ if ! strings .Contains (stderr , "Warning: Logging into " + s .URL + " instead of the configured endpoint https://example.com." ) {
223+ t .Fatalf ("stderr = %q, want endpoint warning" , stderr )
224+ }
225+ if ! strings .Contains (stderr , "export SRC_ENDPOINT=" + s .URL ) {
226+ t .Fatalf ("stderr = %q, want shell tip" , stderr )
227+ }
228+ if ! strings .Contains (stdout , "✔︎ Authenticated as alice on " + s .URL ) {
229+ t .Fatalf ("stdout = %q, want validation output" , stdout )
230+ }
231+ })
232+
233+ t .Run ("warns when no SRC_ENDPOINT is configured in the environment" , func (t * testing.T ) {
234+ if oldValue , ok := os .LookupEnv ("SRC_ENDPOINT" ); ok {
235+ _ = os .Unsetenv ("SRC_ENDPOINT" )
236+ t .Cleanup (func () { _ = os .Setenv ("SRC_ENDPOINT" , oldValue ) })
237+ } else {
238+ _ = os .Unsetenv ("SRC_ENDPOINT" )
239+ }
240+
241+ s := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
242+ fmt .Fprintln (w , `{"data":{"currentUser":{"username":"alice"}}}` )
243+ }))
244+ defer s .Close ()
245+
246+ stdout , stderr , err := runLoginHandler (t , & config {
247+ endpointURL : mustParseURL (t , SGDotComEndpoint ),
248+ accessToken : "x" ,
249+ }, s .URL )
250+ if err != nil {
251+ t .Fatal (err )
252+ }
253+ if ! strings .Contains (stderr , "Warning: No SRC_ENDPOINT is configured in the environment. Logging in using \" " + s .URL + "\" ." ) {
254+ t .Fatalf ("stderr = %q, want default-endpoint warning" , stderr )
255+ }
256+ if ! strings .Contains (stderr , "NOTE: By default src will use \" " + SGDotComEndpoint + "\" if SRC_ENDPOINT is not set." ) {
257+ t .Fatalf ("stderr = %q, want default endpoint note" , stderr )
258+ }
259+ if ! strings .Contains (stdout , "✔︎ Authenticated as alice on " + s .URL ) {
260+ t .Fatalf ("stdout = %q, want validation output" , stdout )
261+ }
262+ })
263+
264+ t .Run ("warns when using config file" , func (t * testing.T ) {
265+ s := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
266+ fmt .Fprintln (w , `{"data":{"currentUser":{"username":"alice"}}}` )
267+ }))
268+ defer s .Close ()
269+
270+ stdout , stderr , err := runLoginHandler (t , & config {
271+ endpointURL : mustParseURL (t , s .URL ),
272+ accessToken : "x" ,
273+ configFilePath : "f" ,
274+ })
275+ if err != nil {
276+ t .Fatal (err )
277+ }
278+ if ! strings .Contains (stderr , "Configuring src with a JSON file is deprecated" ) {
279+ t .Fatalf ("stderr = %q, want deprecation warning" , stderr )
280+ }
281+ if ! strings .Contains (stdout , "✔︎ Authenticated as alice on " + s .URL ) {
282+ t .Fatalf ("stdout = %q, want validation output" , stdout )
283+ }
284+ })
285+ }
286+
167287type fakeOAuthClient struct {
168288 startErr error
169289 startCalled * bool
@@ -192,39 +312,6 @@ func (f fakeOAuthClient) Refresh(context.Context, *oauth.Token) (*oauth.TokenRes
192312 return nil , fmt .Errorf ("unexpected call to Refresh" )
193313}
194314
195- func TestSelectLoginFlow (t * testing.T ) {
196- t .Run ("uses oauth flow when no access token is configured" , func (t * testing.T ) {
197- params := loginParams {
198- cfg : & config {endpointURL : mustParseURL (t , "https://example.com" )},
199- }
200-
201- if got , _ := selectLoginFlow (params ); got != loginFlowOAuth {
202- t .Fatalf ("flow = %v, want %v" , got , loginFlowOAuth )
203- }
204- })
205-
206- t .Run ("uses endpoint conflict flow when auth exists for a different endpoint" , func (t * testing.T ) {
207- params := loginParams {
208- cfg : & config {endpointURL : mustParseURL (t , "https://example.com" ), accessToken : "x" },
209- loginEndpointURL : mustParseURL (t , "https://sourcegraph.example.com" ),
210- }
211-
212- if got , _ := selectLoginFlow (params ); got != loginFlowEndpointConflict {
213- t .Fatalf ("flow = %v, want %v" , got , loginFlowEndpointConflict )
214- }
215- })
216-
217- t .Run ("uses validation flow when auth exists for the selected endpoint" , func (t * testing.T ) {
218- params := loginParams {
219- cfg : & config {endpointURL : mustParseURL (t , "https://example.com" ), accessToken : "x" },
220- }
221-
222- if got , _ := selectLoginFlow (params ); got != loginFlowValidate {
223- t .Fatalf ("flow = %v, want %v" , got , loginFlowValidate )
224- }
225- })
226- }
227-
228315func TestValidateBrowserURL (t * testing.T ) {
229316 tests := []struct {
230317 name string
0 commit comments