Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python-django-sso-example/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ pytz==2021.1
requests==2.25.1
sqlparse==0.4.2
urllib3==1.26.5
workos>=1.23.3
workos>=5.37.0
python-dotenv
187 changes: 185 additions & 2 deletions python-django-sso-example/sso/tests.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,186 @@
from django.test import TestCase
from django.test import TestCase, Client
from django.urls import reverse
from unittest.mock import patch, MagicMock
import os
# Import views module to ensure workos is loaded before patching
from sso import views

# Create your tests here.

class SSOViewTests(TestCase):
def setUp(self):
self.client = Client()
# Set environment variables for testing
os.environ["WORKOS_API_KEY"] = "test_api_key"
os.environ["WORKOS_CLIENT_ID"] = "test_client_id"
os.environ["REDIRECT_URI"] = "http://localhost:8000/auth/callback"

def tearDown(self):
# Clean up environment variables
if "WORKOS_API_KEY" in os.environ:
del os.environ["WORKOS_API_KEY"]
if "WORKOS_CLIENT_ID" in os.environ:
del os.environ["WORKOS_CLIENT_ID"]
if "REDIRECT_URI" in os.environ:
del os.environ["REDIRECT_URI"]

def test_login_no_session(self):
"""Test login view when no session is active"""
response = self.client.get(reverse("login"))
self.assertEqual(response.status_code, 200)
self.assertTemplateUsed(response, "sso/login.html")

def test_login_with_active_session(self):
"""Test login view when session is active"""
session = self.client.session
session["session_active"] = True
session["p_profile"] = {"profile": {"first_name": "Test"}}
session["first_name"] = "Test"
session["raw_profile"] = {"email": "test@example.com"}
session.save()

response = self.client.get(reverse("login"))
self.assertEqual(response.status_code, 200)
self.assertTemplateUsed(response, "sso/login_successful.html")
self.assertIn("p_profile", response.context)
self.assertIn("first_name", response.context)
self.assertIn("raw_profile", response.context)

def test_auth_saml_login(self):
"""Test auth view for SAML login"""
# Create a mock sso object
mock_sso = MagicMock()
mock_sso.get_authorization_url.return_value = "https://api.workos.com/sso/authorize?test=123"

# Create a mock client with sso attribute
mock_client = MagicMock()
mock_client.sso = mock_sso

with patch.object(views, "workos_client", mock_client):
response = self.client.post(
reverse("auth"),
{"login_method": "saml"},
follow=False
)

# Verify get_authorization_url was called with correct params
mock_sso.get_authorization_url.assert_called_once()
call_args = mock_sso.get_authorization_url.call_args
self.assertIn("redirect_uri", call_args.kwargs)
self.assertIn("state", call_args.kwargs)
self.assertIn("organization_id", call_args.kwargs)
self.assertEqual(call_args.kwargs["organization_id"], views.CUSTOMER_ORGANIZATION_ID)
self.assertNotIn("provider", call_args.kwargs)

# Verify redirect response
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, "https://api.workos.com/sso/authorize?test=123")

def test_auth_provider_login(self):
"""Test auth view for provider-based login (Google, Microsoft, etc.)"""
# Create a mock sso object
mock_sso = MagicMock()
mock_sso.get_authorization_url.return_value = "https://api.workos.com/sso/authorize?provider=google"

# Create a mock client with sso attribute
mock_client = MagicMock()
mock_client.sso = mock_sso

with patch.object(views, "workos_client", mock_client):
response = self.client.post(
reverse("auth"),
{"login_method": "google"},
follow=False
)

# Verify get_authorization_url was called with correct params
mock_sso.get_authorization_url.assert_called_once()
call_args = mock_sso.get_authorization_url.call_args
self.assertIn("redirect_uri", call_args.kwargs)
self.assertIn("state", call_args.kwargs)
self.assertIn("provider", call_args.kwargs)
self.assertEqual(call_args.kwargs["provider"], "google")
self.assertNotIn("organization_id", call_args.kwargs)

# Verify redirect response
self.assertEqual(response.status_code, 302)
self.assertEqual(response.url, "https://api.workos.com/sso/authorize?provider=google")

def test_auth_callback_success(self):
"""Test auth_callback view with valid code"""
# Mock the profile response - in SDK v5+, ProfileAndToken uses .dict() method
mock_profile = MagicMock()
mock_profile.dict.return_value = {
"profile": {
"first_name": "John",
"last_name": "Doe",
"email": "john.doe@example.com"
},
"access_token": "test_token"
}

# Create a mock sso object
mock_sso = MagicMock()
mock_sso.get_profile_and_token.return_value = mock_profile

# Create a mock client with sso attribute
mock_client = MagicMock()
mock_client.sso = mock_sso

with patch.object(views, "workos_client", mock_client):
response = self.client.get(
reverse("auth_callback"),
{"code": "test_auth_code"},
follow=True
)

# Verify get_profile_and_token was called with the code
mock_sso.get_profile_and_token.assert_called_once_with("test_auth_code")

# Verify session data was set
self.assertTrue(self.client.session.get("session_active"))
self.assertIn("p_profile", self.client.session)
self.assertEqual(self.client.session["first_name"], "John")
self.assertIn("raw_profile", self.client.session)

# Verify redirect to login
self.assertEqual(response.status_code, 200)
self.assertTemplateUsed(response, "sso/login_successful.html")

def test_auth_callback_missing_code(self):
"""Test auth_callback view when code parameter is missing"""
# Create a mock sso object
mock_sso = MagicMock()

# Create a mock client with sso attribute
mock_client = MagicMock()
mock_client.sso = mock_sso

# This should render login page with error message (not raise KeyError)
with patch.object(views, "workos_client", mock_client):
response = self.client.get(reverse("auth_callback"))
self.assertEqual(response.status_code, 200)
self.assertTemplateUsed(response, "sso/login.html")
self.assertIn("error", response.context)
self.assertEqual(response.context["error"], "missing_code")

def test_logout(self):
"""Test logout view clears session and redirects"""
# Set up a session first
session = self.client.session
session["session_active"] = True
session["p_profile"] = {"profile": {"first_name": "Test"}}
session.save()

# Verify session has data
self.assertTrue(self.client.session.get("session_active"))

# Call logout
response = self.client.get(reverse("logout"), follow=True)

# Verify session is cleared
self.assertFalse(self.client.session.get("session_active"))
self.assertNotIn("p_profile", self.client.session)

# Verify redirect to login
self.assertEqual(response.status_code, 200)
self.assertTemplateUsed(response, "sso/login.html")
135 changes: 110 additions & 25 deletions python-django-sso-example/sso/views.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,62 @@
import os
import workos
from workos import WorkOSClient
import json
from django.conf import settings
from django.shortcuts import redirect, render
from django.urls import reverse
from pathlib import Path
from dotenv import load_dotenv

# Load environment variables from .env file if it exists
# BASE_DIR is the project root (where manage.py is located)
# views.py is at: python-django-sso-example/sso/views.py
# So we need to go up 2 levels to get to python-django-sso-example/
BASE_DIR = Path(__file__).resolve().parent.parent
env_path = BASE_DIR / ".env"
load_dotenv(env_path, override=False) # Don't override existing env vars


# Initialize WorkOS client
# Note: In SDK v5+, we use WorkOSClient instance instead of workos.client module
def get_workos_client():
"""Get WorkOS client instance (initialized lazily)"""
if not hasattr(get_workos_client, '_instance'):
# Reload .env file in case it wasn't loaded at import time
load_dotenv(env_path, override=False)

api_key = os.getenv("WORKOS_API_KEY")
client_id = os.getenv("WORKOS_CLIENT_ID")
if not api_key or not client_id:
raise ValueError(
"WorkOS API key and client ID must be set via WORKOS_API_KEY and WORKOS_CLIENT_ID environment variables. "
"Please check your .env file or export these variables."
)
get_workos_client._instance = WorkOSClient(
api_key=api_key,
client_id=client_id
)
return get_workos_client._instance

# For compatibility with other examples, create workos_client variable
# Initialize it if env vars are available, otherwise it will be created on first use
try:
if os.getenv("WORKOS_API_KEY") and os.getenv("WORKOS_CLIENT_ID"):
workos_client = WorkOSClient(
api_key=os.getenv("WORKOS_API_KEY"),
client_id=os.getenv("WORKOS_CLIENT_ID")
)
else:
workos_client = None
except ValueError:
# If env vars aren't set at import time, use lazy initialization
workos_client = None


workos.api_key = os.getenv("WORKOS_API_KEY")
workos.client_id = os.getenv("WORKOS_CLIENT_ID")

# In workos_django/settings.py, you can use DEBUG=True for local development,
# but you must use DEBUG=False in order to test the full authentication flow
# with the WorkOS API.
workos.base_api_url = (
"http://localhost:8000/" if settings.DEBUG else workos.base_api_url
)
# Set custom API base URL for local development
if settings.DEBUG:
os.environ["WORKOS_API_BASE_URL"] = "http://localhost:8000/"

# Constants
# Required: Fill in CUSTOMER_ORGANIZATION_ID for the desired organization from the WorkOS Dashboard

CUSTOMER_ORGANIZATION_ID = "xxx"
CUSTOMER_ORGANIZATION_ID = os.getenv("CUSTOMER_ORGANIZATION_ID")
REDIRECT_URI = os.getenv("REDIRECT_URI")


Expand All @@ -40,29 +77,77 @@ def login(request):


def auth(request):
if not REDIRECT_URI:
return render(
request,
"sso/login.html",
{"error": "configuration_error", "error_description": "REDIRECT_URI is not configured"},
)

login_type = request.POST.get("login_method")
if not login_type:
return render(
request,
"sso/login.html",
{"error": "missing_login_method", "error_description": "Login method is required"},
)

login_type = request.POST["login_method"]
params = {"redirect_uri": REDIRECT_URI, "state": {}}

if login_type == "saml":
params["organization"] = CUSTOMER_ORGANIZATION_ID
if not CUSTOMER_ORGANIZATION_ID:
return render(
request,
"sso/login.html",
{"error": "configuration_error", "error_description": "CUSTOMER_ORGANIZATION_ID is not configured"},
)
params["organization_id"] = CUSTOMER_ORGANIZATION_ID
else:
params["provider"] = login_type

authorization_url = workos.client.sso.get_authorization_url(**params)
client = workos_client if workos_client else get_workos_client()
authorization_url = client.sso.get_authorization_url(**params)

return redirect(authorization_url)


def auth_callback(request):
code = request.GET["code"]
profile = workos.client.sso.get_profile_and_token(code)
p_profile = profile.to_dict()
request.session["p_profile"] = p_profile
request.session["first_name"] = p_profile["profile"]["first_name"]
request.session["raw_profile"] = p_profile["profile"]
request.session["session_active"] = True
return redirect("login")
# Check for error response from WorkOS
if "error" in request.GET:
error = request.GET.get("error")
error_description = request.GET.get("error_description", "An error occurred during authentication")
# Log the error and redirect back to login with error message
return render(
request,
"sso/login.html",
{"error": error, "error_description": error_description},
)

# Get the authorization code
code = request.GET.get("code")
if not code:
return render(
request,
"sso/login.html",
{"error": "missing_code", "error_description": "No authorization code received"},
)

try:
client = workos_client if workos_client else get_workos_client()
profile = client.sso.get_profile_and_token(code)
# In SDK v5+, ProfileAndToken is a Pydantic model - use .dict() to convert to dict
p_profile = profile.dict()
request.session["p_profile"] = p_profile
request.session["first_name"] = p_profile["profile"]["first_name"]
request.session["raw_profile"] = p_profile["profile"]
request.session["session_active"] = True
return redirect("login")
except Exception as e:
return render(
request,
"sso/login.html",
{"error": "authentication_error", "error_description": str(e)},
)


def logout(request):
Expand Down
1 change: 0 additions & 1 deletion python-django-sso-example/workos_django/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
https://docs.djangoproject.com/en/3.1/howto/static-files/#configuring-static-files
"""
DEBUG = False
# DEBUG = True

ALLOWED_HOSTS = ["127.0.0.1", "localhost"]

Expand Down