Skip to content

Commit d2350a1

Browse files
authored
test: add more tests in the agent (#1572)
1 parent 4ca228f commit d2350a1

File tree

5 files changed

+468
-3
lines changed

5 files changed

+468
-3
lines changed

tests/test_memory.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from pandasai.helpers.memory import Memory
2+
3+
4+
def test_to_json_empty_memory():
5+
memory = Memory()
6+
assert memory.to_json() == []
7+
8+
9+
def test_to_json_with_messages():
10+
memory = Memory()
11+
12+
# Add test messages
13+
memory.add("Hello", is_user=True)
14+
memory.add("Hi there!", is_user=False)
15+
memory.add("How are you?", is_user=True)
16+
17+
expected_json = [
18+
{"role": "user", "message": "Hello"},
19+
{"role": "assistant", "message": "Hi there!"},
20+
{"role": "user", "message": "How are you?"},
21+
]
22+
23+
assert memory.to_json() == expected_json
24+
25+
26+
def test_to_json_message_order():
27+
memory = Memory()
28+
29+
# Add messages in specific order
30+
messages = [("Message 1", True), ("Message 2", False), ("Message 3", True)]
31+
32+
for msg, is_user in messages:
33+
memory.add(msg, is_user=is_user)
34+
35+
result = memory.to_json()
36+
37+
# Verify order is preserved
38+
assert len(result) == 3
39+
assert result[0]["message"] == "Message 1"
40+
assert result[1]["message"] == "Message 2"
41+
assert result[2]["message"] == "Message 3"
42+
43+
44+
def test_to_openai_messages_empty():
45+
memory = Memory()
46+
assert memory.to_openai_messages() == []
47+
48+
49+
def test_to_openai_messages_with_agent_description():
50+
memory = Memory(agent_description="I am a helpful assistant")
51+
memory.add("Hello", is_user=True)
52+
memory.add("Hi there!", is_user=False)
53+
54+
expected_messages = [
55+
{"role": "system", "content": "I am a helpful assistant"},
56+
{"role": "user", "content": "Hello"},
57+
{"role": "assistant", "content": "Hi there!"},
58+
]
59+
60+
assert memory.to_openai_messages() == expected_messages
61+
62+
63+
def test_to_openai_messages_without_agent_description():
64+
memory = Memory()
65+
memory.add("Hello", is_user=True)
66+
memory.add("Hi there!", is_user=False)
67+
memory.add("How are you?", is_user=True)
68+
69+
expected_messages = [
70+
{"role": "user", "content": "Hello"},
71+
{"role": "assistant", "content": "Hi there!"},
72+
{"role": "user", "content": "How are you?"},
73+
]
74+
75+
assert memory.to_openai_messages() == expected_messages

tests/unit_tests/agent/test_agent.py

Lines changed: 107 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
import os
22
from typing import Optional
3-
from unittest.mock import MagicMock, Mock, mock_open, patch
3+
from unittest.mock import ANY, MagicMock, Mock, mock_open, patch
44

55
import pandas as pd
66
import pytest
77

88
from pandasai import DatasetLoader, VirtualDataFrame
99
from pandasai.agent.base import Agent
1010
from pandasai.config import Config, ConfigManager
11+
from pandasai.core.response.error import ErrorResponse
1112
from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema
1213
from pandasai.dataframe.base import DataFrame
13-
from pandasai.exceptions import CodeExecutionError
14+
from pandasai.exceptions import CodeExecutionError, InvalidLLMOutputType
1415
from pandasai.llm.fake import FakeLLM
1516

1617

@@ -466,3 +467,107 @@ def test_execute_sql_query_error_no_dataframe(self, agent):
466467

467468
with pytest.raises(ValueError, match="No DataFrames available"):
468469
agent._execute_sql_query(query)
470+
471+
def test_process_query(self, agent, config):
472+
"""Test the _process_query method with successful execution"""
473+
query = "What is the average age?"
474+
output_type = "number"
475+
476+
# Mock the necessary methods
477+
agent.generate_code = Mock(return_value="result = df['age'].mean()")
478+
agent.execute_with_retries = Mock(return_value=30.5)
479+
agent._state.config.enable_cache = True
480+
agent._state.cache = Mock()
481+
482+
# Execute the query
483+
result = agent._process_query(query, output_type)
484+
485+
# Verify the result
486+
assert result == 30.5
487+
488+
# Verify method calls
489+
agent.generate_code.assert_called_once()
490+
agent.execute_with_retries.assert_called_once_with("result = df['age'].mean()")
491+
agent._state.cache.set.assert_called_once()
492+
493+
def test_process_query_execution_error(self, agent, config):
494+
"""Test the _process_query method with execution error"""
495+
query = "What is the invalid operation?"
496+
497+
# Mock methods to simulate error
498+
agent.generate_code = Mock(return_value="invalid_code")
499+
agent.execute_with_retries = Mock(
500+
side_effect=CodeExecutionError("Execution failed")
501+
)
502+
agent._handle_exception = Mock(return_value="Error handled")
503+
504+
# Execute the query
505+
result = agent._process_query(query)
506+
507+
# Verify error handling
508+
assert result == "Error handled"
509+
agent._handle_exception.assert_called_once_with("invalid_code")
510+
511+
def test_regenerate_code_after_invalid_llm_output_error(self, agent):
512+
"""Test code regeneration with InvalidLLMOutputType error"""
513+
from pandasai.exceptions import InvalidLLMOutputType
514+
515+
code = "test code"
516+
error = InvalidLLMOutputType("Invalid output type")
517+
518+
with patch(
519+
"pandasai.agent.base.get_correct_output_type_error_prompt"
520+
) as mock_prompt:
521+
mock_prompt.return_value = "corrected prompt"
522+
agent._code_generator.generate_code = MagicMock(return_value="new code")
523+
524+
result = agent._regenerate_code_after_error(code, error)
525+
526+
mock_prompt.assert_called_once_with(agent._state, code, ANY)
527+
agent._code_generator.generate_code.assert_called_once_with(
528+
"corrected prompt"
529+
)
530+
assert result == "new code"
531+
532+
def test_regenerate_code_after_other_error(self, agent):
533+
"""Test code regeneration with non-InvalidLLMOutputType error"""
534+
code = "test code"
535+
error = ValueError("Some other error")
536+
537+
with patch(
538+
"pandasai.agent.base.get_correct_error_prompt_for_sql"
539+
) as mock_prompt:
540+
mock_prompt.return_value = "sql error prompt"
541+
agent._code_generator.generate_code = MagicMock(return_value="new code")
542+
543+
result = agent._regenerate_code_after_error(code, error)
544+
545+
mock_prompt.assert_called_once_with(agent._state, code, ANY)
546+
agent._code_generator.generate_code.assert_called_once_with(
547+
"sql error prompt"
548+
)
549+
assert result == "new code"
550+
551+
def test_handle_exception(self, agent):
552+
"""Test that _handle_exception properly formats and logs exceptions"""
553+
test_code = "print(1/0)" # Code that will raise a ZeroDivisionError
554+
555+
# Mock the logger to verify it's called
556+
mock_logger = MagicMock()
557+
agent._state.logger = mock_logger
558+
559+
# Create an actual exception to handle
560+
try:
561+
exec(test_code)
562+
except:
563+
# Call the method
564+
result = agent._handle_exception(test_code)
565+
566+
# Verify the result is an ErrorResponse
567+
assert isinstance(result, ErrorResponse)
568+
assert result.last_code_executed == test_code
569+
assert "ZeroDivisionError" in result.error
570+
571+
# Verify the error was logged
572+
mock_logger.log.assert_called_once()
573+
assert "Processing failed with error" in mock_logger.log.call_args[0][0]
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import os
2+
from io import BytesIO
3+
from unittest.mock import Mock, mock_open, patch
4+
from zipfile import ZipFile
5+
6+
import pandas as pd
7+
import pytest
8+
9+
from pandasai.data_loader.semantic_layer_schema import (
10+
Column,
11+
SemanticLayerSchema,
12+
Source,
13+
)
14+
from pandasai.dataframe.base import DataFrame
15+
from pandasai.exceptions import DatasetNotFound, PandaAIApiKeyError
16+
17+
18+
@pytest.fixture
19+
def mock_env(monkeypatch):
20+
monkeypatch.setenv("PANDABI_API_KEY", "test_api_key")
21+
22+
23+
@pytest.fixture
24+
def sample_df():
25+
return pd.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"]})
26+
27+
28+
@pytest.fixture
29+
def mock_zip_content():
30+
zip_buffer = BytesIO()
31+
with ZipFile(zip_buffer, "w") as zip_file:
32+
zip_file.writestr("test.csv", "col1,col2\n1,a\n2,b\n3,c")
33+
return zip_buffer.getvalue()
34+
35+
36+
@pytest.fixture
37+
def mock_schema():
38+
return SemanticLayerSchema(
39+
name="test_schema",
40+
source=Source(type="parquet", path="data.parquet", table="test_table"),
41+
columns=[
42+
Column(name="col1", type="integer"),
43+
Column(name="col2", type="string"),
44+
],
45+
)
46+
47+
48+
def test_pull_success(mock_env, sample_df, mock_zip_content, mock_schema, tmp_path):
49+
with patch("pandasai.dataframe.base.get_pandaai_session") as mock_session, patch(
50+
"pandasai.dataframe.base.find_project_root"
51+
) as mock_root, patch(
52+
"pandasai.DatasetLoader.create_loader_from_path"
53+
) as mock_loader, patch("builtins.open", mock_open()) as mock_file:
54+
# Setup mocks
55+
mock_response = Mock()
56+
mock_response.status_code = 200
57+
mock_response.content = mock_zip_content
58+
mock_session.return_value.get.return_value = mock_response
59+
mock_root.return_value = str(tmp_path)
60+
61+
mock_loader_instance = Mock()
62+
mock_loader_instance.load.return_value = DataFrame(
63+
sample_df, schema=mock_schema
64+
)
65+
mock_loader.return_value = mock_loader_instance
66+
67+
# Create DataFrame instance and call pull
68+
df = DataFrame(sample_df, path="test/path", schema=mock_schema)
69+
df.pull()
70+
71+
# Verify API call
72+
mock_session.return_value.get.assert_called_once_with(
73+
"/datasets/pull",
74+
headers={
75+
"accept": "application/json",
76+
"x-authorization": "Bearer test_api_key",
77+
},
78+
params={"path": "test/path"},
79+
)
80+
81+
# Verify file operations
82+
assert mock_file.call_count > 0
83+
84+
85+
def test_pull_missing_api_key(sample_df, mock_schema):
86+
with patch("os.environ.get") as mock_env_get:
87+
mock_env_get.return_value = None
88+
with pytest.raises(PandaAIApiKeyError):
89+
df = DataFrame(sample_df, path="test/path", schema=mock_schema)
90+
df.pull()
91+
92+
93+
def test_pull_api_error(mock_env, sample_df, mock_schema):
94+
with patch("pandasai.dataframe.base.get_pandaai_session") as mock_session:
95+
mock_response = Mock()
96+
mock_response.status_code = 404
97+
mock_session.return_value.get.return_value = mock_response
98+
99+
df = DataFrame(sample_df, path="test/path", schema=mock_schema)
100+
with pytest.raises(DatasetNotFound, match="Remote dataset not found to pull!"):
101+
df.pull()
102+
103+
104+
def test_pull_file_exists(mock_env, sample_df, mock_zip_content, mock_schema, tmp_path):
105+
with patch("pandasai.dataframe.base.get_pandaai_session") as mock_session, patch(
106+
"pandasai.dataframe.base.find_project_root"
107+
) as mock_root, patch(
108+
"pandasai.DatasetLoader.create_loader_from_path"
109+
) as mock_loader, patch("builtins.open", mock_open()) as mock_file, patch(
110+
"os.path.exists"
111+
) as mock_exists, patch("os.makedirs") as mock_makedirs:
112+
# Setup mocks
113+
mock_response = Mock()
114+
mock_response.status_code = 200
115+
mock_response.content = mock_zip_content
116+
mock_session.return_value.get.return_value = mock_response
117+
mock_root.return_value = str(tmp_path)
118+
mock_exists.return_value = True
119+
120+
mock_loader_instance = Mock()
121+
mock_loader_instance.load.return_value = DataFrame(
122+
sample_df, schema=mock_schema
123+
)
124+
mock_loader.return_value = mock_loader_instance
125+
126+
# Create DataFrame instance and call pull
127+
df = DataFrame(sample_df, path="test/path", schema=mock_schema)
128+
df.pull()
129+
130+
# Verify directory creation
131+
mock_makedirs.assert_called_with(
132+
os.path.dirname(
133+
os.path.join(str(tmp_path), "datasets", "test/path", "test.csv")
134+
),
135+
exist_ok=True,
136+
)

0 commit comments

Comments
 (0)