Skip to content

Commit 7573249

Browse files
committed
add mocks and use monkey patch for setting env vars
1 parent 4d9a536 commit 7573249

File tree

4 files changed

+147
-18
lines changed

4 files changed

+147
-18
lines changed

src/fastapi_app/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,11 @@ async def lifespan(app: FastAPI):
5252
await engine.dispose()
5353

5454

55-
def create_app(is_testing: bool = False):
55+
def create_app(testing: bool = False):
5656
env = Env()
5757

5858
if not os.getenv("RUNNING_IN_PRODUCTION"):
59-
if not is_testing:
59+
if not testing:
6060
env.read_env(".env")
6161
logging.basicConfig(level=logging.INFO)
6262
else:

tests/__init__.py

Whitespace-only changes.

tests/conftest.py

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import os
22
from pathlib import Path
3-
from unittest.mock import patch
3+
from unittest import mock
44

55
import pytest
66
from fastapi.testclient import TestClient
77
from sqlalchemy.ext.asyncio import async_sessionmaker
88

99
from fastapi_app import create_app
1010
from fastapi_app.globals import global_storage
11+
from tests.mocks import MockAzureCredential
1112

1213
POSTGRES_HOST = "localhost"
1314
POSTGRES_USERNAME = "admin"
@@ -20,34 +21,54 @@
2021

2122

2223
@pytest.fixture(scope="session")
23-
def setup_env():
24-
os.environ["POSTGRES_HOST"] = POSTGRES_HOST
25-
os.environ["POSTGRES_USERNAME"] = POSTGRES_USERNAME
26-
os.environ["POSTGRES_DATABASE"] = POSTGRES_DATABASE
27-
os.environ["POSTGRES_PASSWORD"] = POSTGRES_PASSWORD
28-
os.environ["POSTGRES_SSL"] = POSTGRES_SSL
29-
os.environ["POSTGRESQL_DATABASE_URL"] = POSTGRESQL_DATABASE_URL
30-
os.environ["RUNNING_IN_PRODUCTION"] = "False"
31-
os.environ["OPENAI_API_KEY"] = "fakekey"
24+
def monkeypatch_session():
25+
with pytest.MonkeyPatch.context() as monkeypatch_session:
26+
yield monkeypatch_session
3227

3328

3429
@pytest.fixture(scope="session")
35-
def mock_azure_credential():
36-
"""Mock the Azure credential for testing."""
37-
with patch("azure.identity.DefaultAzureCredential", return_value=None):
30+
def mock_session_env(monkeypatch_session):
31+
"""Mock the environment variables for testing."""
32+
with mock.patch.dict(os.environ, clear=True):
33+
# Database
34+
monkeypatch_session.setenv("POSTGRES_HOST", POSTGRES_HOST)
35+
monkeypatch_session.setenv("POSTGRES_USERNAME", POSTGRES_USERNAME)
36+
monkeypatch_session.setenv("POSTGRES_DATABASE", POSTGRES_DATABASE)
37+
monkeypatch_session.setenv("POSTGRES_PASSWORD", POSTGRES_PASSWORD)
38+
monkeypatch_session.setenv("POSTGRES_SSL", POSTGRES_SSL)
39+
monkeypatch_session.setenv("POSTGRESQL_DATABASE_URL", POSTGRESQL_DATABASE_URL)
40+
monkeypatch_session.setenv("RUNNING_IN_PRODUCTION", "False")
41+
# Azure Subscription
42+
monkeypatch_session.setenv("AZURE_SUBSCRIPTION_ID", "test-storage-subid")
43+
# OpenAI
44+
monkeypatch_session.setenv("AZURE_OPENAI_CHATGPT_MODEL", "gpt-35-turbo")
45+
monkeypatch_session.setenv("OPENAI_API_KEY", "fakekey")
46+
# Allowed Origin
47+
monkeypatch_session.setenv("ALLOWED_ORIGIN", "https://frontend.com")
48+
49+
if os.getenv("AZURE_USE_AUTHENTICATION") is not None:
50+
monkeypatch_session.delenv("AZURE_USE_AUTHENTICATION")
3851
yield
3952

4053

4154
@pytest.fixture(scope="session")
42-
def app(setup_env, mock_azure_credential):
55+
def app(mock_session_env):
4356
"""Create a FastAPI app."""
4457
if not Path("src/static/").exists():
4558
pytest.skip("Please generate frontend files first!")
46-
return create_app(is_testing=True)
59+
return create_app(testing=True)
60+
61+
62+
@pytest.fixture(scope="function")
63+
def mock_default_azure_credential(mock_session_env):
64+
"""Mock the Azure credential for testing."""
65+
with mock.patch("azure.identity.DefaultAzureCredential") as mock_default_azure_credential:
66+
mock_default_azure_credential.return_value = MockAzureCredential()
67+
yield mock_default_azure_credential
4768

4869

4970
@pytest.fixture(scope="function")
50-
def test_client(app):
71+
def test_client(monkeypatch, app, mock_default_azure_credential):
5172
"""Create a test client."""
5273

5374
with TestClient(app) as test_client:

tests/mocks.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import json
2+
from collections import namedtuple
3+
4+
import openai.types
5+
from azure.core.credentials_async import AsyncTokenCredential
6+
7+
MOCK_EMBEDDING_DIMENSIONS = 1536
8+
MOCK_EMBEDDING_MODEL_NAME = "text-embedding-ada-002"
9+
10+
MockToken = namedtuple("MockToken", ["token", "expires_on", "value"])
11+
12+
13+
class MockAzureCredential(AsyncTokenCredential):
14+
async def get_token(self, uri):
15+
return MockToken("", 9999999999, "")
16+
17+
18+
class MockAzureCredentialExpired(AsyncTokenCredential):
19+
def __init__(self):
20+
self.access_number = 0
21+
22+
async def get_token(self, uri):
23+
self.access_number += 1
24+
if self.access_number == 1:
25+
return MockToken("", 0, "")
26+
else:
27+
return MockToken("", 9999999999, "")
28+
29+
30+
class MockAsyncPageIterator:
31+
def __init__(self, data):
32+
self.data = data
33+
34+
def __aiter__(self):
35+
return self
36+
37+
async def __anext__(self):
38+
if not self.data:
39+
raise StopAsyncIteration
40+
return self.data.pop(0) # This should be a list of dictionaries.
41+
42+
43+
class MockCaption:
44+
def __init__(self, text, highlights=None, additional_properties=None):
45+
self.text = text
46+
self.highlights = highlights or []
47+
self.additional_properties = additional_properties or {}
48+
49+
50+
class MockResponse:
51+
def __init__(self, text, status):
52+
self.text = text
53+
self.status = status
54+
55+
async def text(self):
56+
return self._text
57+
58+
async def __aexit__(self, exc_type, exc, tb):
59+
pass
60+
61+
async def __aenter__(self):
62+
return self
63+
64+
async def json(self):
65+
return json.loads(self.text)
66+
67+
68+
class MockEmbeddingsClient:
69+
def __init__(self, create_embedding_response: openai.types.CreateEmbeddingResponse):
70+
self.create_embedding_response = create_embedding_response
71+
72+
async def create(self, *args, **kwargs) -> openai.types.CreateEmbeddingResponse:
73+
return self.create_embedding_response
74+
75+
76+
class MockClient:
77+
def __init__(self, embeddings_client):
78+
self.embeddings = embeddings_client
79+
80+
81+
def mock_computervision_response():
82+
return MockResponse(
83+
status=200,
84+
text=json.dumps(
85+
{
86+
"vector": [
87+
0.011925711,
88+
0.023533698,
89+
0.010133852,
90+
0.0063544377,
91+
-0.00038590943,
92+
0.0013952175,
93+
0.009054946,
94+
-0.033573493,
95+
-0.002028305,
96+
],
97+
"modelVersion": "2022-04-11",
98+
}
99+
),
100+
)
101+
102+
103+
class MockSynthesisResult:
104+
def __init__(self, result):
105+
self.__result = result
106+
107+
def get(self):
108+
return self.__result

0 commit comments

Comments
 (0)