Skip to content

Commit 732dfac

Browse files
authored
Sso one time (#658)
* Created functions to create one time codes * Created exchange endpoint * Renamed method calls and created unit tests * Removed some comments * Created check for empty code * Enabled SSO * Added logs
1 parent 729ed84 commit 732dfac

File tree

4 files changed

+260
-4
lines changed

4 files changed

+260
-4
lines changed

apps/api/docker-compose.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ services:
2828
environment:
2929
MONGODB_URI: mongodb://root:example@mongo:27017
3030
DEPLOYMENT: LOCAL
31-
UCI_SSO_ENABLED: "false"
31+
UCI_SSO_ENABLED: "true"
3232

3333
volumes:
3434
mongodb_data_volume: {}

apps/api/src/routers/saml.py

Lines changed: 102 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import json
22
import os
3+
import secrets
4+
import time
35
from urllib.parse import urlparse
46
from functools import lru_cache
57
from logging import getLogger
@@ -28,6 +30,9 @@
2830
ALLOWED_RELAY_HOSTS = {"zothacks.com", "www.zothacks.com", "localhost"}
2931

3032

33+
ONE_TIME_CODE_TTL = 5 * 60 # in seconds
34+
35+
3136
def _is_valid_relay_state(relay_state: str) -> bool:
3237
# allow same-origin relative paths
3338
if relay_state.startswith("/"):
@@ -113,6 +118,80 @@ async def _update_last_login(user: NativeUser) -> None:
113118
)
114119

115120

121+
async def _generate_one_time_code(user: NativeUser) -> str:
122+
"""Generate a secure one-time code and store it in MongoDB."""
123+
code = secrets.token_urlsafe(32)
124+
expires_at = time.time() + ONE_TIME_CODE_TTL
125+
126+
code_data = {
127+
"code": code,
128+
"user": {
129+
"ucinetid": user.ucinetid,
130+
"display_name": user.display_name,
131+
"email": user.email,
132+
"affiliations": user.affiliations,
133+
},
134+
"expires_at": expires_at,
135+
"created_at": time.time(),
136+
}
137+
138+
try:
139+
await mongodb_handler.insert(Collection.CODES, code_data)
140+
log.info("Generated one-time code")
141+
return code
142+
except Exception as e:
143+
log.error(f"Failed to store one-time code in MongoDB: {e}")
144+
raise HTTPException(
145+
status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to generate one-time code"
146+
)
147+
148+
149+
async def _validate_one_time_code(code: str) -> NativeUser:
150+
"""Validate the one-time code and return the associated user."""
151+
try:
152+
if not code:
153+
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Invalid code")
154+
155+
code_data = await mongodb_handler.retrieve_one(Collection.CODES, {"code": code})
156+
157+
if not code_data:
158+
log.info("Code not found")
159+
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Invalid code")
160+
161+
current_time = time.time()
162+
if current_time > code_data["expires_at"]:
163+
# Remove if expired
164+
await mongodb_handler.raw_update_one(
165+
Collection.CODES, {"code": code}, {"$unset": {"code": ""}}
166+
)
167+
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Code expired")
168+
169+
user_data = code_data["user"]
170+
user = NativeUser(
171+
ucinetid=user_data["ucinetid"],
172+
display_name=user_data["display_name"],
173+
email=user_data["email"],
174+
affiliations=user_data["affiliations"],
175+
)
176+
177+
# Remove the code after successful validation (single-use)
178+
await mongodb_handler.raw_update_one(
179+
Collection.CODES, {"code": code}, {"$unset": {"code": ""}}
180+
)
181+
182+
log.info(f"Validated code: {code}")
183+
184+
return user
185+
186+
except HTTPException:
187+
raise
188+
except Exception as e:
189+
log.error(f"Failed to validate one-time code: {e}")
190+
raise HTTPException(
191+
status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to validate one-time code"
192+
)
193+
194+
116195
@router.get("/login")
117196
async def login(req: Request, return_to: str = "/") -> RedirectResponse:
118197
"""Initiate login to SSO identity provider."""
@@ -189,9 +268,18 @@ async def acs(
189268

190269
await _update_last_login(user)
191270

192-
res = RedirectResponse(relay_state, status_code=status.HTTP_303_SEE_OTHER)
193-
issue_user_identity(user, res)
194-
return res
271+
# Generate one-time code if returning to external site
272+
if relay_state.startswith("https://zothacks.com"):
273+
log.info("Relay starts with zothacks, generating one-time code")
274+
code = await _generate_one_time_code(user)
275+
redirect_url = f"{relay_state}?code={code}"
276+
return RedirectResponse(redirect_url, status_code=status.HTTP_303_SEE_OTHER)
277+
else:
278+
# Same-domain redirect: set cookie directly
279+
log.info("Relay from irvinehacks, issuing identity")
280+
res = RedirectResponse(relay_state, status_code=status.HTTP_303_SEE_OTHER)
281+
issue_user_identity(user, res)
282+
return res
195283

196284

197285
@router.get("/sls")
@@ -214,3 +302,14 @@ async def get_saml_metadata() -> Response:
214302
raise HTTPException(500, "Could not prepare SP metadata")
215303

216304
return Response(metadata, media_type="application/xml")
305+
306+
307+
@router.get("/exchange")
308+
async def exchange_code(code: str) -> RedirectResponse:
309+
"""Exchange one-time code for JWT."""
310+
log.info(f"Attempting exchange with code: {code}")
311+
user = await _validate_one_time_code(code)
312+
res = RedirectResponse("/", status_code=status.HTTP_303_SEE_OTHER)
313+
log.info("Issuing user identity")
314+
issue_user_identity(user, res)
315+
return res

apps/api/src/services/mongodb_handler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class Collection(str, Enum):
4747
SETTINGS = "settings"
4848
EVENTS = "events"
4949
EMAILS = "emails"
50+
CODES = "codes"
5051

5152

5253
def get_database() -> AgnosticDatabase[Any]:

apps/api/tests/test_saml.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
1+
import time
12
from unittest.mock import ANY, AsyncMock, Mock, patch
23

4+
import pytest
5+
from fastapi import HTTPException
36
from fastapi.testclient import TestClient
47
from onelogin.saml2.auth import OneLogin_Saml2_Auth
58
from onelogin.saml2.settings import OneLogin_Saml2_Settings
69

710
from auth.user_identity import NativeUser, issue_user_identity
811
from routers import saml
12+
from services.mongodb_handler import Collection
913

1014
SSO_URL = "https://shib.service.uci.edu/idp/profile/SAML2/Redirect/SSO"
1115
SAMPLE_SETTINGS = OneLogin_Saml2_Settings(
@@ -103,3 +107,155 @@ def test_saml_acs_succeeds(
103107
)
104108
# response sets appropriate JWT cookie for user identity
105109
assert res.headers["Set-Cookie"].startswith("irvinehacks_auth=ey")
110+
111+
112+
# One-time code tests
113+
@patch("services.mongodb_handler.insert", autospec=True)
114+
async def test_generate_one_time_code_creates_and_stores_code(
115+
mock_mongodb_insert: AsyncMock,
116+
) -> None:
117+
"""Test that _generate_one_time_code creates a code and stores it in MongoDB."""
118+
user = NativeUser(
119+
ucinetid="testuser",
120+
display_name="Test User",
121+
email="test@uci.edu",
122+
affiliations=["student"],
123+
)
124+
125+
mock_mongodb_insert.return_value = "test_code_id"
126+
127+
code = await saml._generate_one_time_code(user)
128+
129+
# Verify code is generated (should be a URL-safe string)
130+
assert isinstance(code, str)
131+
assert len(code) > 0
132+
133+
# Verify MongoDB insert was called with correct data
134+
mock_mongodb_insert.assert_awaited_once()
135+
call_args = mock_mongodb_insert.await_args
136+
assert call_args is not None
137+
assert call_args[0][0] == Collection.CODES # Collection.CODES
138+
code_data = call_args[0][1] # The data dict
139+
140+
assert code_data["code"] == code
141+
assert code_data["user"]["ucinetid"] == "testuser"
142+
assert code_data["user"]["display_name"] == "Test User"
143+
assert code_data["user"]["email"] == "test@uci.edu"
144+
assert code_data["user"]["affiliations"] == ["student"]
145+
assert "expires_at" in code_data
146+
assert "created_at" in code_data
147+
148+
149+
@patch("services.mongodb_handler.retrieve_one", autospec=True)
150+
@patch("services.mongodb_handler.raw_update_one", autospec=True)
151+
async def test_validate_one_time_code_with_valid_code(
152+
mock_raw_update_one: AsyncMock,
153+
mock_retrieve_one: AsyncMock,
154+
) -> None:
155+
"""Test that _validate_one_time_code works with a valid code."""
156+
current_time = time.time()
157+
code_data = {
158+
"code": "valid_code_123",
159+
"user": {
160+
"ucinetid": "testuser",
161+
"display_name": "Test User",
162+
"email": "test@uci.edu",
163+
"affiliations": ["student"],
164+
},
165+
"expires_at": current_time + 300, # 5 minutes from now
166+
"created_at": current_time,
167+
}
168+
169+
mock_retrieve_one.return_value = code_data
170+
mock_raw_update_one.return_value = True
171+
172+
user = await saml._validate_one_time_code("valid_code_123")
173+
174+
# Verify user is reconstructed correctly
175+
assert user.ucinetid == "testuser"
176+
assert user.display_name == "Test User"
177+
assert user.email == "test@uci.edu"
178+
assert user.affiliations == ["student"]
179+
180+
# Verify code was removed after validation
181+
mock_raw_update_one.assert_awaited_once_with(
182+
Collection.CODES, {"code": "valid_code_123"}, {"$unset": {"code": ""}}
183+
)
184+
185+
186+
@patch("services.mongodb_handler.retrieve_one", autospec=True)
187+
async def test_validate_one_time_code_with_invalid_code(
188+
mock_retrieve_one: AsyncMock,
189+
) -> None:
190+
"""Test that _validate_one_time_code fails with an invalid code."""
191+
mock_retrieve_one.return_value = None # Code not found
192+
193+
with pytest.raises(HTTPException) as exc_info:
194+
await saml._validate_one_time_code("invalid_code")
195+
196+
assert exc_info.value.status_code == 400
197+
assert "Invalid code" in str(exc_info.value.detail)
198+
199+
200+
@patch("services.mongodb_handler.retrieve_one", autospec=True)
201+
@patch("services.mongodb_handler.raw_update_one", autospec=True)
202+
async def test_validate_one_time_code_with_expired_code(
203+
mock_raw_update_one: AsyncMock,
204+
mock_retrieve_one: AsyncMock,
205+
) -> None:
206+
"""Test that _validate_one_time_code fails with an expired code."""
207+
current_time = time.time()
208+
code_data = {
209+
"code": "expired_code_123",
210+
"user": {
211+
"ucinetid": "testuser",
212+
"display_name": "Test User",
213+
"email": "test@uci.edu",
214+
"affiliations": ["student"],
215+
},
216+
"expires_at": current_time - 100, # Expired 100 seconds ago
217+
"created_at": current_time - 400,
218+
}
219+
220+
mock_retrieve_one.return_value = code_data
221+
mock_raw_update_one.return_value = True
222+
223+
with pytest.raises(HTTPException) as exc_info:
224+
await saml._validate_one_time_code("expired_code_123")
225+
226+
assert exc_info.value.status_code == 400
227+
assert "Code expired" in str(exc_info.value.detail)
228+
229+
# Verify expired code was cleaned up
230+
mock_raw_update_one.assert_awaited_once_with(
231+
Collection.CODES, {"code": "expired_code_123"}, {"$unset": {"code": ""}}
232+
)
233+
234+
235+
@patch("routers.saml._validate_one_time_code", autospec=True)
236+
@patch("routers.saml.issue_user_identity", autospec=True)
237+
def test_exchange_code_with_valid_code(
238+
mock_issue_user_identity: Mock,
239+
mock_validate_one_time_code: AsyncMock,
240+
) -> None:
241+
"""Test that /exchange endpoint works with a valid code."""
242+
user = NativeUser(
243+
ucinetid="testuser",
244+
display_name="Test User",
245+
email="test@uci.edu",
246+
affiliations=["student"],
247+
)
248+
249+
mock_validate_one_time_code.return_value = user
250+
mock_issue_user_identity.side_effect = issue_user_identity
251+
252+
res = client.get("/exchange?code=valid_code_123")
253+
254+
assert res.status_code == 303
255+
assert res.headers["location"] == "/"
256+
257+
# Verify user identity was issued
258+
mock_issue_user_identity.assert_called_once_with(user, ANY)
259+
260+
# Verify JWT cookie was set
261+
assert res.headers["Set-Cookie"].startswith("irvinehacks_auth=ey")

0 commit comments

Comments
 (0)