|
| 1 | +import time |
1 | 2 | from unittest.mock import ANY, AsyncMock, Mock, patch |
2 | 3 |
|
| 4 | +import pytest |
| 5 | +from fastapi import HTTPException |
3 | 6 | from fastapi.testclient import TestClient |
4 | 7 | from onelogin.saml2.auth import OneLogin_Saml2_Auth |
5 | 8 | from onelogin.saml2.settings import OneLogin_Saml2_Settings |
6 | 9 |
|
7 | 10 | from auth.user_identity import NativeUser, issue_user_identity |
8 | 11 | from routers import saml |
| 12 | +from services.mongodb_handler import Collection |
9 | 13 |
|
10 | 14 | SSO_URL = "https://shib.service.uci.edu/idp/profile/SAML2/Redirect/SSO" |
11 | 15 | SAMPLE_SETTINGS = OneLogin_Saml2_Settings( |
@@ -103,3 +107,155 @@ def test_saml_acs_succeeds( |
103 | 107 | ) |
104 | 108 | # response sets appropriate JWT cookie for user identity |
105 | 109 | assert res.headers["Set-Cookie"].startswith("irvinehacks_auth=ey") |
| 110 | + |
| 111 | + |
| 112 | +# One-time code tests |
| 113 | +@patch("services.mongodb_handler.insert", autospec=True) |
| 114 | +async def test_generate_one_time_code_creates_and_stores_code( |
| 115 | + mock_mongodb_insert: AsyncMock, |
| 116 | +) -> None: |
| 117 | + """Test that _generate_one_time_code creates a code and stores it in MongoDB.""" |
| 118 | + user = NativeUser( |
| 119 | + ucinetid="testuser", |
| 120 | + display_name="Test User", |
| 121 | + email="test@uci.edu", |
| 122 | + affiliations=["student"], |
| 123 | + ) |
| 124 | + |
| 125 | + mock_mongodb_insert.return_value = "test_code_id" |
| 126 | + |
| 127 | + code = await saml._generate_one_time_code(user) |
| 128 | + |
| 129 | + # Verify code is generated (should be a URL-safe string) |
| 130 | + assert isinstance(code, str) |
| 131 | + assert len(code) > 0 |
| 132 | + |
| 133 | + # Verify MongoDB insert was called with correct data |
| 134 | + mock_mongodb_insert.assert_awaited_once() |
| 135 | + call_args = mock_mongodb_insert.await_args |
| 136 | + assert call_args is not None |
| 137 | + assert call_args[0][0] == Collection.CODES # Collection.CODES |
| 138 | + code_data = call_args[0][1] # The data dict |
| 139 | + |
| 140 | + assert code_data["code"] == code |
| 141 | + assert code_data["user"]["ucinetid"] == "testuser" |
| 142 | + assert code_data["user"]["display_name"] == "Test User" |
| 143 | + assert code_data["user"]["email"] == "test@uci.edu" |
| 144 | + assert code_data["user"]["affiliations"] == ["student"] |
| 145 | + assert "expires_at" in code_data |
| 146 | + assert "created_at" in code_data |
| 147 | + |
| 148 | + |
| 149 | +@patch("services.mongodb_handler.retrieve_one", autospec=True) |
| 150 | +@patch("services.mongodb_handler.raw_update_one", autospec=True) |
| 151 | +async def test_validate_one_time_code_with_valid_code( |
| 152 | + mock_raw_update_one: AsyncMock, |
| 153 | + mock_retrieve_one: AsyncMock, |
| 154 | +) -> None: |
| 155 | + """Test that _validate_one_time_code works with a valid code.""" |
| 156 | + current_time = time.time() |
| 157 | + code_data = { |
| 158 | + "code": "valid_code_123", |
| 159 | + "user": { |
| 160 | + "ucinetid": "testuser", |
| 161 | + "display_name": "Test User", |
| 162 | + "email": "test@uci.edu", |
| 163 | + "affiliations": ["student"], |
| 164 | + }, |
| 165 | + "expires_at": current_time + 300, # 5 minutes from now |
| 166 | + "created_at": current_time, |
| 167 | + } |
| 168 | + |
| 169 | + mock_retrieve_one.return_value = code_data |
| 170 | + mock_raw_update_one.return_value = True |
| 171 | + |
| 172 | + user = await saml._validate_one_time_code("valid_code_123") |
| 173 | + |
| 174 | + # Verify user is reconstructed correctly |
| 175 | + assert user.ucinetid == "testuser" |
| 176 | + assert user.display_name == "Test User" |
| 177 | + assert user.email == "test@uci.edu" |
| 178 | + assert user.affiliations == ["student"] |
| 179 | + |
| 180 | + # Verify code was removed after validation |
| 181 | + mock_raw_update_one.assert_awaited_once_with( |
| 182 | + Collection.CODES, {"code": "valid_code_123"}, {"$unset": {"code": ""}} |
| 183 | + ) |
| 184 | + |
| 185 | + |
| 186 | +@patch("services.mongodb_handler.retrieve_one", autospec=True) |
| 187 | +async def test_validate_one_time_code_with_invalid_code( |
| 188 | + mock_retrieve_one: AsyncMock, |
| 189 | +) -> None: |
| 190 | + """Test that _validate_one_time_code fails with an invalid code.""" |
| 191 | + mock_retrieve_one.return_value = None # Code not found |
| 192 | + |
| 193 | + with pytest.raises(HTTPException) as exc_info: |
| 194 | + await saml._validate_one_time_code("invalid_code") |
| 195 | + |
| 196 | + assert exc_info.value.status_code == 400 |
| 197 | + assert "Invalid code" in str(exc_info.value.detail) |
| 198 | + |
| 199 | + |
| 200 | +@patch("services.mongodb_handler.retrieve_one", autospec=True) |
| 201 | +@patch("services.mongodb_handler.raw_update_one", autospec=True) |
| 202 | +async def test_validate_one_time_code_with_expired_code( |
| 203 | + mock_raw_update_one: AsyncMock, |
| 204 | + mock_retrieve_one: AsyncMock, |
| 205 | +) -> None: |
| 206 | + """Test that _validate_one_time_code fails with an expired code.""" |
| 207 | + current_time = time.time() |
| 208 | + code_data = { |
| 209 | + "code": "expired_code_123", |
| 210 | + "user": { |
| 211 | + "ucinetid": "testuser", |
| 212 | + "display_name": "Test User", |
| 213 | + "email": "test@uci.edu", |
| 214 | + "affiliations": ["student"], |
| 215 | + }, |
| 216 | + "expires_at": current_time - 100, # Expired 100 seconds ago |
| 217 | + "created_at": current_time - 400, |
| 218 | + } |
| 219 | + |
| 220 | + mock_retrieve_one.return_value = code_data |
| 221 | + mock_raw_update_one.return_value = True |
| 222 | + |
| 223 | + with pytest.raises(HTTPException) as exc_info: |
| 224 | + await saml._validate_one_time_code("expired_code_123") |
| 225 | + |
| 226 | + assert exc_info.value.status_code == 400 |
| 227 | + assert "Code expired" in str(exc_info.value.detail) |
| 228 | + |
| 229 | + # Verify expired code was cleaned up |
| 230 | + mock_raw_update_one.assert_awaited_once_with( |
| 231 | + Collection.CODES, {"code": "expired_code_123"}, {"$unset": {"code": ""}} |
| 232 | + ) |
| 233 | + |
| 234 | + |
| 235 | +@patch("routers.saml._validate_one_time_code", autospec=True) |
| 236 | +@patch("routers.saml.issue_user_identity", autospec=True) |
| 237 | +def test_exchange_code_with_valid_code( |
| 238 | + mock_issue_user_identity: Mock, |
| 239 | + mock_validate_one_time_code: AsyncMock, |
| 240 | +) -> None: |
| 241 | + """Test that /exchange endpoint works with a valid code.""" |
| 242 | + user = NativeUser( |
| 243 | + ucinetid="testuser", |
| 244 | + display_name="Test User", |
| 245 | + email="test@uci.edu", |
| 246 | + affiliations=["student"], |
| 247 | + ) |
| 248 | + |
| 249 | + mock_validate_one_time_code.return_value = user |
| 250 | + mock_issue_user_identity.side_effect = issue_user_identity |
| 251 | + |
| 252 | + res = client.get("/exchange?code=valid_code_123") |
| 253 | + |
| 254 | + assert res.status_code == 303 |
| 255 | + assert res.headers["location"] == "/" |
| 256 | + |
| 257 | + # Verify user identity was issued |
| 258 | + mock_issue_user_identity.assert_called_once_with(user, ANY) |
| 259 | + |
| 260 | + # Verify JWT cookie was set |
| 261 | + assert res.headers["Set-Cookie"].startswith("irvinehacks_auth=ey") |
0 commit comments