
Commit 847757a

feature(SqlLoader): transformations in SqlLoader (#1569)
Co-authored-by: scaliseraoul-sinaptik <raoul@sinaptik.ai>
1 parent 62d7490 commit 847757a

File tree

7 files changed (+166, -129 lines)

pandasai/data_loader/loader.py

Lines changed: 7 additions & 9 deletions
@@ -1,5 +1,6 @@
 import os
 
+import pandas as pd
 import yaml
 
 from pandasai.dataframe.base import DataFrame
@@ -12,6 +13,7 @@
 )
 from .query_builder import QueryBuilder
 from .semantic_layer_schema import SemanticLayerSchema
+from .transformation_manager import TransformationManager
 from .view_query_builder import ViewQueryBuilder
 
 
@@ -72,16 +74,12 @@ def load(self) -> DataFrame:
         """
         raise MethodNotImplementedError("Loader not instantiated")
 
-    def _build_dataset(
-        self, schema: SemanticLayerSchema, dataset_path: str
-    ) -> DataFrame:
-        self.schema = schema
-        self.dataset_path = dataset_path
-        is_view = schema.view
+    def _apply_transformations(self, df: pd.DataFrame) -> pd.DataFrame:
+        if not self.schema.transformations:
+            return df
 
-        self.query_builder = (
-            ViewQueryBuilder(schema) if is_view else QueryBuilder(schema)
-        )
+        transformation_manager = TransformationManager(df)
+        return transformation_manager.apply_transformations(self.schema.transformations)
 
     def _get_abs_dataset_path(self):
         return os.path.join(find_project_root(), "datasets", self.dataset_path)
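
Note on the helper above: _apply_transformations now lives on the shared DatasetLoader base, so any loader that sets self.schema gets the same behaviour: a schema with no declared transformations passes the frame through untouched, otherwise the work is delegated to TransformationManager. A minimal sketch of that contract, using stand-in objects rather than real pandasai classes:

    # Illustrative sketch only; _StubSchema and _StubLoader are not pandasai
    # classes, they just mimic the attributes the shared helper reads.
    import pandas as pd

    class _StubSchema:
        transformations = None  # nothing declared in the semantic layer

    class _StubLoader:
        def __init__(self, schema):
            self.schema = schema

        def _apply_transformations(self, df: pd.DataFrame) -> pd.DataFrame:
            # Mirrors DatasetLoader._apply_transformations introduced above.
            if not self.schema.transformations:
                return df
            raise NotImplementedError("TransformationManager path not sketched here")

    frame = pd.DataFrame({"email": ["user1@example.com"]})
    assert _StubLoader(_StubSchema())._apply_transformations(frame) is frame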

pandasai/data_loader/local_loader.py

Lines changed: 0 additions & 7 deletions
@@ -69,10 +69,3 @@ def _filter_columns(self, df: pd.DataFrame) -> pd.DataFrame:
         df_columns = df.columns.tolist()
         columns_to_keep = [col for col in df_columns if col in schema_columns]
         return df[columns_to_keep]
-
-    def _apply_transformations(self, df: pd.DataFrame) -> pd.DataFrame:
-        if not self.schema.transformations:
-            return df
-
-        transformation_manager = TransformationManager(df)
-        return transformation_manager.apply_transformations(self.schema.transformations)

pandasai/data_loader/sql_loader.py

Lines changed: 4 additions & 3 deletions
@@ -24,7 +24,6 @@ def __init__(self, schema: SemanticLayerSchema, dataset_path: str):
         self.query_builder: QueryBuilder = QueryBuilder(schema)
 
     def load(self) -> VirtualDataFrame:
-        self.query_builder = QueryBuilder(self.schema)
         return VirtualDataFrame(
             schema=self.schema,
             data_loader=SQLDatasetLoader(self.schema, self.dataset_path),
@@ -37,9 +36,11 @@ def execute_query(self, query: str, params: Optional[list] = None) -> pd.DataFrame:
 
         formatted_query = self.query_builder.format_query(query)
         load_function = self._get_loader_function(source_type)
-
         try:
-            return load_function(connection_info, formatted_query, params)
+            dataframe: pd.DataFrame = load_function(
+                connection_info, formatted_query, params
+            )
+            return self._apply_transformations(dataframe)
         except Exception as e:
             raise RuntimeError(
                 f"Failed to execute query for '{source_type}' with: {formatted_query}"

pandasai/data_loader/transformation_manager.py

Lines changed: 12 additions & 19 deletions
@@ -1,9 +1,9 @@
 from typing import Any, List, Optional, Union
 
-import numpy as np
 import pandas as pd
 
 from ..exceptions import UnsupportedTransformation
+from .semantic_layer_schema import Transformation
 
 
 class TransformationManager:
@@ -268,12 +268,12 @@ def format_date(self, column: str, date_format: str) -> "TransformationManager":
             TransformationManager: Self for method chaining
 
         Example:
-            >>> df = pd.DataFrame({"date": ["2024-01-01 12:30:45"]})
+            >>> df = pd.DataFrame({"date": ["2025-01-01 12:30:45"]})
             >>> manager = TransformationManager(df)
            >>> result = manager.format_date("date", "%Y-%m-%d").df
            >>> print(result)
                     date
-            0  2024-01-01
+            0  2025-01-01
         """
         self.df[column] = self.df[column].dt.strftime(date_format)
         return self
@@ -307,28 +307,28 @@ def to_numeric(
         return self
 
     def to_datetime(
-        self, column: str, format: Optional[str] = None, errors: str = "coerce"
+        self, column: str, _format: Optional[str] = None, errors: str = "coerce"
     ) -> "TransformationManager":
         """Convert values in a column to datetime type.
 
         Args:
             column (str): The column to transform
-            format (Optional[str]): Expected date format of the input
+            _format (Optional[str]): Expected date format of the input
             errors (str): How to handle parsing errors
 
         Returns:
             TransformationManager: Self for method chaining
 
         Example:
-            >>> df = pd.DataFrame({"date": ["2024-01-01", "invalid"]})
+            >>> df = pd.DataFrame({"date": ["2025-01-01", "invalid"]})
             >>> manager = TransformationManager(df)
             >>> result = manager.to_datetime("date", errors="coerce").df
             >>> print(result)
                     date
-            0 2024-01-01
+            0 2025-01-01
             1        NaT
         """
-        self.df[column] = pd.to_datetime(self.df[column], format=format, errors=errors)
+        self.df[column] = pd.to_datetime(self.df[column], format=_format, errors=errors)
         return self
 
     def fill_na(self, column: str, value: Any) -> "TransformationManager":
@@ -884,27 +884,20 @@ def rename(self, column: str, new_name: str) -> "TransformationManager":
         return self
 
     def apply_transformations(
-        self, transformations: Optional[List[dict]] = None
+        self, transformations: List[Transformation]
     ) -> pd.DataFrame:
         """Apply a list of transformations to the DataFrame.
 
         Args:
-            transformations (Optional[List[dict]]): List of transformation configurations
+            transformations List[Transformation]: List of transformation configurations
 
         Returns:
             pd.DataFrame: The transformed DataFrame
         """
-        if not transformations:
-            return self.df
 
         for transformation in transformations:
-            # Handle both dict and object transformations
-            if isinstance(transformation, dict):
-                transformation_type = transformation["type"]
-                params = transformation["params"]
-            else:
-                transformation_type = transformation.type
-                params = transformation.params
+            transformation_type = transformation.type
+            params = transformation.params
 
             handler = self.transformation_handlers.get(transformation_type)
             if not handler:
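
For context, the chaining API these docstrings describe: each method returns the manager itself, and the transformed frame is exposed as .df. A minimal sketch built only from the documented methods (note that format_date calls .dt.strftime, so the column must already be datetime-typed, hence the to_datetime call first):

    import pandas as pd

    from pandasai.data_loader.transformation_manager import TransformationManager

    df = pd.DataFrame({"date": ["2025-01-01 12:30:45"]})
    result = (
        TransformationManager(df)
        .to_datetime("date")              # parse the string column to datetime64
        .format_date("date", "%Y-%m-%d")  # re-format as a plain date string
        .df
    )
    print(result)
    #          date
    # 0  2025-01-01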

pandasai/dataframe/virtual_dataframe.py

Lines changed: 3 additions & 4 deletions
@@ -1,15 +1,14 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, ClassVar
+from typing import TYPE_CHECKING, Optional
 
 import pandas as pd
 
-from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema
 from pandasai.dataframe.base import DataFrame
 from pandasai.exceptions import VirtualizationError
 
 if TYPE_CHECKING:
-    from pandasai.data_loader.loader import DatasetLoader
+    from pandasai.data_loader.sql_loader import SQLDatasetLoader
 
 
 class VirtualDataFrame(DataFrame):
@@ -25,7 +24,7 @@ class VirtualDataFrame(DataFrame):
     ]
 
     def __init__(self, *args, **kwargs):
-        self._loader: DatasetLoader = kwargs.pop("data_loader", None)
+        self._loader: Optional[SQLDatasetLoader] = kwargs.pop("data_loader", None)
         if not self._loader:
             raise VirtualizationError("Data loader is required for virtualization!")
         self._head = None

tests/unit_tests/data_loader/test_loader.py

Lines changed: 0 additions & 87 deletions
@@ -7,7 +7,6 @@
 from pandasai.data_loader.loader import DatasetLoader
 from pandasai.data_loader.local_loader import LocalDatasetLoader
 from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema
-from pandasai.data_loader.sql_loader import SQLDatasetLoader
 from pandasai.dataframe.base import DataFrame
 from pandasai.exceptions import InvalidDataSourceType
 
@@ -111,92 +110,6 @@ def test_apply_transformations(self, sample_schema):
         assert result.iloc[0]["email"] != "user1@example.com"
         assert result.iloc[0]["timestamp"].tzname() == "UTC"
 
-    def test_load_mysql_source(self, mysql_schema):
-        """Test loading data from a MySQL source creates a VirtualDataFrame and handles queries correctly."""
-        with patch("os.path.exists", return_value=True), patch(
-            "builtins.open", mock_open(read_data=str(mysql_schema.to_yaml()))
-        ), patch(
-            "pandasai.data_loader.sql_loader.SQLDatasetLoader.execute_query"
-        ) as mock_execute_query:
-            # Mock the query results
-            mock_execute_query.return_value = DataFrame(
-                pd.DataFrame(
-                    {
-                        "email": ["test@example.com"],
-                        "first_name": ["John"],
-                        "timestamp": [pd.Timestamp.now()],
-                    }
-                )
-            )
-
-            loader = SQLDatasetLoader(mysql_schema, "test/users")
-            logging.debug("Loading schema from dataset path: %s", loader)
-            result = loader.load()
-
-            # Test that we get a VirtualDataFrame
-            assert isinstance(result, DataFrame)
-            assert result.schema == mysql_schema
-
-            # Test that load_head() works
-            head_result = result.head()
-            assert isinstance(head_result, DataFrame)
-            assert "email" in head_result.columns
-            assert "first_name" in head_result.columns
-            assert "timestamp" in head_result.columns
-
-            # Verify the SQL query was executed correctly
-            mock_execute_query.assert_called_once_with(
-                "SELECT email, first_name, timestamp FROM users ORDER BY RAND() LIMIT 5"
-            )
-
-            # Test executing a custom query
-            custom_query = "SELECT email FROM users WHERE first_name = 'John'"
-            result.execute_sql_query(custom_query)
-            mock_execute_query.assert_called_with(custom_query)
-
-    def test_build_dataset_mysql_schema(self, mysql_schema):
-        """Test loading data from a MySQL schema directly and creates a VirtualDataFrame and handles queries correctly."""
-        with patch("os.path.exists", return_value=True), patch(
-            "builtins.open", mock_open(read_data=str(mysql_schema.to_yaml()))
-        ), patch(
-            "pandasai.data_loader.sql_loader.SQLDatasetLoader.execute_query"
-        ) as mock_execute_query:
-            # Mock the query results
-            mock_execute_query.return_value = DataFrame(
-                pd.DataFrame(
-                    {
-                        "email": ["test@example.com"],
-                        "first_name": ["John"],
-                        "timestamp": [pd.Timestamp.now()],
-                    }
-                )
-            )
-
-            loader = SQLDatasetLoader(mysql_schema, "test/test")
-            logging.debug("Loading schema from dataset path: %s", loader)
-            result = loader.load()
-
-            # Test that we get a VirtualDataFrame
-            assert isinstance(result, DataFrame)
-            assert result.schema == mysql_schema
-
-            # Test that load_head() works
-            head_result = result.head()
-            assert isinstance(head_result, DataFrame)
-            assert "email" in head_result.columns
-            assert "first_name" in head_result.columns
-            assert "timestamp" in head_result.columns
-
-            # Verify the SQL query was executed correctly
-            mock_execute_query.assert_called_once_with(
-                "SELECT email, first_name, timestamp FROM users ORDER BY RAND() LIMIT 5"
-            )
-
-            # Test executing a custom query
-            custom_query = "SELECT email FROM users WHERE first_name = 'John'"
-            result.execute_sql_query(custom_query)
-            mock_execute_query.assert_called_with(custom_query)
-
     def test_build_dataset_csv_schema(self, sample_schema):
         """Test loading data from a CSV schema directly and creates a VirtualDataFrame and handles queries correctly."""
         with patch("os.path.exists", return_value=True), patch(
