Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 50 additions & 27 deletions authentication/openshift.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ type OpenShiftAuthenticator struct {
oauth2Config oauth2.Config
cookieName string
handler http.Handler
oauthEnabled bool
}

//nolint:funlen
Expand Down Expand Up @@ -139,27 +140,9 @@ func newOpenshiftAuthenticator(c map[string]interface{}, tenant string,
MaxRetries: 0, // Retry indefinitely.
})

var authURL *url.URL

var tokenURL *url.URL

for b.Reset(); b.Ongoing(); {
authURL, tokenURL, err = openshift.DiscoverOAuth(client)
if err != nil {
level.Error(logger).Log(
"tenant", tenant,
"msg", errors.Wrap(err, "unable to auto discover OpenShift OAuth endpoints"))
registrationRetryCount.WithLabelValues(tenant, OpenShiftAuthenticatorType).Inc()
b.Wait()

continue
}

break
}
authURL, tokenURL, oauthEnabled := discoverOAuthEndpoints(client, logger, tenant, registrationRetryCount, b)

var clientID string

var clientSecret string

for b.Reset(); b.Ongoing(); {
Expand Down Expand Up @@ -221,7 +204,11 @@ func newOpenshiftAuthenticator(c map[string]interface{}, tenant string,
client: client,
config: config,
cookieName: fmt.Sprintf("observatorium_%s", tenant),
oauth2Config: oauth2.Config{
oauthEnabled: oauthEnabled,
}

if oauthEnabled {
osAuthenticator.oauth2Config = oauth2.Config{
ClientID: clientID,
ClientSecret: clientSecret,
Endpoint: oauth2.Endpoint{
Expand All @@ -235,15 +222,17 @@ func newOpenshiftAuthenticator(c map[string]interface{}, tenant string,
defaultOAuthScopeListProjects,
},
RedirectURL: config.RedirectURL,
},
}
r := chi.NewRouter()
r.Use(tracing.WithChiRoutePattern)
r.Handle(loginRoute, osAuthenticator.openshiftLoginHandler())
r.Handle(callbackRoute, osAuthenticator.openshiftCallbackHandler())
osAuthenticator.handler = r

} else {
osAuthenticator.handler = chi.NewRouter()
}

r := chi.NewRouter()
r.Use(tracing.WithChiRoutePattern)
r.Handle(loginRoute, osAuthenticator.openshiftLoginHandler())
r.Handle(callbackRoute, osAuthenticator.openshiftCallbackHandler())
osAuthenticator.handler = r

return osAuthenticator, nil
}

Expand Down Expand Up @@ -442,6 +431,13 @@ func (a OpenShiftAuthenticator) Middleware() Middleware {
// when users went through the OAuth2 flow supported by this
// provider. Observatorium stores a self-signed JWT token on a
// cookie per tenant to identify the subject of incoming requests.
if !a.oauthEnabled {
msg := "OAuth authentication not available"
level.Debug(a.logger).Log("msg", msg)
httperr.PrometheusAPIError(w, msg, http.StatusUnauthorized)
return
}

cookie, err := r.Cookie(a.cookieName)
if err != nil {
tenant, ok := GetTenant(r.Context())
Expand Down Expand Up @@ -541,3 +537,30 @@ func (a OpenShiftAuthenticator) GRPCMiddleware() grpc.StreamServerInterceptor {
func (a OpenShiftAuthenticator) Handler() (string, http.Handler) {
return "/openshift/{tenant}", a.handler
}

func discoverOAuthEndpoints(client *http.Client, logger log.Logger, tenant string, registrationRetryCount *prometheus.CounterVec, b *backoff.Backoff) (*url.URL, *url.URL, bool) {
authURL, tokenURL, err := openshift.DiscoverOAuth(client)
if err != nil {
if strings.Contains(err.Error(), "got 404") || strings.Contains(err.Error(), "OAuth server not found") {
level.Warn(logger).Log(
"tenant", tenant,
"msg", errors.Wrap(err, "OpenShift OAuth endpoint not available, likely using external OIDC authentication. But bearer token authentication will continue to work"))
return nil, nil, false
} else {
// Other errors, retry with backoff
for b.Reset(); b.Ongoing(); {
authURL, tokenURL, err = openshift.DiscoverOAuth(client)
if err != nil {
level.Error(logger).Log(
"tenant", tenant,
"msg", errors.Wrap(err, "unable to auto discover OpenShift OAuth endpoints"))
registrationRetryCount.WithLabelValues(tenant, OpenShiftAuthenticatorType).Inc()
b.Wait()
continue
}
break
}
}
}
return authURL, tokenURL, true
}
7 changes: 5 additions & 2 deletions authentication/openshift/discovery.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (
)

const (
oauthWellKnownPath = "/.well-known/oauth-authorization-server"
OauthWellKnownPath = "/.well-known/oauth-authorization-server"

// ServiceAccountNamespacePath is the path to the default serviceaccount namespace.
ServiceAccountNamespacePath = "/var/run/secrets/kubernetes.io/serviceaccount/namespace"
Expand Down Expand Up @@ -54,7 +54,7 @@ func DiscoverCredentials(name string) (string, string, error) {
// DiscoverOAuth return the authorization and token endpoints of the OpenShift OAuth server.
// Returns an error if requesting the `/.well-known/oauth-authorization-server` fails.
func DiscoverOAuth(client *http.Client) (authURL, tokenURL *url.URL, err error) {
oauthURL := toKubeAPIURLWithPath(oauthWellKnownPath)
oauthURL := toKubeAPIURLWithPath(OauthWellKnownPath)

req, err := http.NewRequest(http.MethodGet, oauthURL.String(), nil)
if err != nil {
Expand All @@ -73,6 +73,9 @@ func DiscoverOAuth(client *http.Client) (authURL, tokenURL *url.URL, err error)
}

if resp.StatusCode < 200 || resp.StatusCode >= 300 {
if resp.StatusCode == 404 {
return nil, nil, fmt.Errorf("OAuth server not found")
}
return nil, nil, fmt.Errorf("got %d %s", resp.StatusCode, body)
}

Expand Down
149 changes: 149 additions & 0 deletions authentication/openshift_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package authentication

import (
"context"
"net"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"

"github.com/efficientgo/core/backoff"
"github.com/go-chi/chi/v5"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"

"github.com/observatorium/api/authentication/openshift"
"github.com/observatorium/api/logger"
)

// redirectTransport redirects all requests to the target host.
type redirectTransport struct {
targetHost string
transport http.RoundTripper
}

func (rt *redirectTransport) RoundTrip(req *http.Request) (*http.Response, error) {
// Redirect request to mock server while keeping the path
req.URL.Host = rt.targetHost
req.URL.Scheme = "http"
return rt.transport.RoundTrip(req)
}

func TestDiscoverOAuthEndpoints_OAuthEnabled(t *testing.T) {
tenant := "tenant"
logger := logger.NewLogger("warn", logger.LogFormatLogfmt, "")
r := chi.NewMux()

r.Get(openshift.OauthWellKnownPath, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
if _, err := w.Write([]byte(`{
"authorization_endpoint": "https://oauth.example.com/authorize",
"token_endpoint": "https://oauth.example.com/token"
}`)); err != nil {
t.Fatalf("failed to write response %v", err)
}
})

mockAPIServer := httptest.NewServer(r)
defer mockAPIServer.Close()

mockURL, err := url.Parse(mockAPIServer.URL)
if err != nil {
t.Fatalf("failed to parse mock server URL: %v", err)
}

// Split host and port for KUBERNETES env vars
host, port, err := net.SplitHostPort(mockURL.Host)
if err != nil {
t.Fatalf("failed to parse mock server address: %v", err)
}

t.Setenv("KUBERNETES_SERVICE_HOST", host)
t.Setenv("KUBERNETES_SERVICE_PORT", port)

retryCounter := prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "test_retries_total",
Help: "Total number of OAuth discovery retries",
},
[]string{"tenant", "type"},
)

client := &http.Client{
Transport: &redirectTransport{
targetHost: mockURL.Host,
transport: http.DefaultTransport,
},
}

b := backoff.New(context.TODO(), backoff.Config{
Min: 500 * time.Millisecond,
Max: 5 * time.Second,
MaxRetries: 0, // Retry indefinitely.
})

authURL, tokenURL, oauthEnabled := discoverOAuthEndpoints(client, logger, tenant, retryCounter, b)

assert.NotNil(t, authURL)
assert.NotNil(t, tokenURL)
assert.True(t, oauthEnabled)
assert.Equal(t, authURL.String(), "https://oauth.example.com/authorize")
assert.Equal(t, tokenURL.String(), "https://oauth.example.com/token")
}

func TestDiscoverOAuthEndpoints_OAuthDisabled(t *testing.T) {
tenant := "tenant"
logger := logger.NewLogger("warn", logger.LogFormatLogfmt, "")
r := chi.NewMux()

r.Get(openshift.OauthWellKnownPath, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
if _, err := w.Write([]byte("404 page not found")); err != nil {
t.Fatalf("failed to write response %v", err)
}
})

mockAPIServer := httptest.NewServer(r)
defer mockAPIServer.Close()

mockURL, err := url.Parse(mockAPIServer.URL)
if err != nil {
t.Fatalf("failed to parse mock server URL: %v", err)
}

// Split host and port for KUBERNETES env vars
host, port, err := net.SplitHostPort(mockURL.Host)
if err != nil {
t.Fatalf("failed to parse mock server address: %v", err)
}

t.Setenv("KUBERNETES_SERVICE_HOST", host)
t.Setenv("KUBERNETES_SERVICE_PORT", port)

retryCounter := prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "test_retries_total",
Help: "Total number of OAuth discovery retries",
},
[]string{"tenant", "type"},
)

client := &http.Client{
Transport: &redirectTransport{
targetHost: mockURL.Host,
transport: http.DefaultTransport,
},
}

b := backoff.New(context.TODO(), backoff.Config{
Min: 500 * time.Millisecond,
Max: 5 * time.Second,
MaxRetries: 0, // Retry indefinitely.
})

authURL, tokenURL, oauthEnabled := discoverOAuthEndpoints(client, logger, tenant, retryCounter, b)

assert.Nil(t, authURL)
assert.Nil(t, tokenURL)
assert.False(t, oauthEnabled)
}