
Commit 62d7490

scaliseraoul, scaliseraoul-sinaptik, and gventuri authored
fix(pandasai-sql): dropping support for sqlite (#1558)
* fix(pandasai-sql): dropping support for sqlite
* fix(): standardize name in schema
* fix(): dropping name and description from dataframe
* fix: print leftover
* refactor(Loader): refactoring loader
* refactor(Loader): avoid creating temporary pandas dataframe

---------

Co-authored-by: scaliseraoul-sinaptik <raoul@sinaptik.ai>
Co-authored-by: Gabriele Venturi <lele.venturi@gmail.com>
1 parent d267095 commit 62d7490

28 files changed: +985 additions, -1409 deletions
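Since this commit removes the built-in sqlite source, a sqlite-backed dataset now has to be read into a DataFrame before being registered. Below is a minimal migration sketch, not part of this commit, assuming pandasai exposes DataFrame and create() at the package level as the diffs further down show; the file, table, and dataset path names are hypothetical.

import sqlite3

import pandas as pd
import pandasai as pai

# Read the sqlite table yourself; the sqlite connector is no longer shipped.
with sqlite3.connect("my_db.sqlite") as conn:  # hypothetical file
    data = pd.read_sql("SELECT * FROM heart_data", conn)  # hypothetical table

# Hand the frame to create(); the dataset name is now derived from the path,
# so the former name= argument is gone.
df = pai.create(
    path="my-org/heart-data",  # hypothetical org/dataset path
    df=pai.DataFrame(data),
    description="Heart data previously loaded through the sqlite connector",
)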

extensions/connectors/sql/pandasai_sql/__init__.py

Lines changed: 1 addition & 13 deletions
@@ -2,10 +2,7 @@
 
 import pandas as pd
 
-from pandasai.data_loader.semantic_layer_schema import (
-    SQLConnectionConfig,
-    SqliteConnectionConfig,
-)
+from pandasai.data_loader.semantic_layer_schema import SQLConnectionConfig
 
 
 def load_from_mysql(
@@ -38,15 +35,6 @@ def load_from_postgres(
     return pd.read_sql(query, conn, params=params)
 
 
-def load_from_sqlite(
-    connection_info: SqliteConnectionConfig, query: str, params: Optional[list] = None
-):
-    import sqlite3
-
-    conn = sqlite3.connect(connection_info.file_path)
-    return pd.read_sql(query, conn, params=params)
-
-
 def load_from_cockroachdb(
     connection_info: SQLConnectionConfig, query: str, params: Optional[list] = None
 ):
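For anyone who still needs the dropped helper, the removed function is small enough to keep in user code. A minimal sketch under that assumption, taking the file path directly instead of the old SqliteConnectionConfig:

import sqlite3
from typing import Optional

import pandas as pd


def load_from_sqlite(file_path: str, query: str, params: Optional[list] = None) -> pd.DataFrame:
    # Mirrors the removed connector: open the local file and let pandas run the query.
    conn = sqlite3.connect(file_path)
    return pd.read_sql(query, conn, params=params)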

extensions/connectors/sql/poetry.lock

Lines changed: 121 additions & 2 deletions
Some generated files are not rendered by default.

extensions/connectors/sql/tests/test_sql.py

Lines changed: 1 addition & 54 deletions
@@ -8,13 +8,9 @@
     load_from_cockroachdb,
     load_from_mysql,
     load_from_postgres,
-    load_from_sqlite,
 )
 
-from pandasai.data_loader.semantic_layer_schema import (
-    SQLConnectionConfig,
-    SqliteConnectionConfig,
-)
+from pandasai.data_loader.semantic_layer_schema import SQLConnectionConfig
 
 
 class TestDatabaseLoader(unittest.TestCase):
@@ -95,32 +91,6 @@ def test_load_from_postgres(self, mock_read_sql, mock_psycopg2_connect):
         self.assertIsInstance(result, pd.DataFrame)
         self.assertEqual(result.shape, (2, 2))
 
-    @patch("sqlite3.connect")
-    @patch("pandas.read_sql")
-    def test_load_from_sqlite(self, mock_read_sql, mock_sqlite3_connect):
-        # Setup the mock return values
-        mock_conn = MagicMock()
-        mock_sqlite3_connect.return_value = mock_conn
-        mock_read_sql.return_value = pd.DataFrame(
-            {"column1": [9, 10], "column2": [11, 12]}
-        )
-
-        # Test data
-        connection_info = {"file_path": "test_db.sqlite"}
-        query = "SELECT * FROM test_table"
-
-        connection_config = SqliteConnectionConfig(**connection_info)
-
-        result = load_from_sqlite(connection_config, query)
-
-        # Assert that the connection is made and SQL query is executed
-        mock_sqlite3_connect.assert_called_once_with("test_db.sqlite")
-        mock_read_sql.assert_called_once_with(query, mock_conn, params=None)
-
-        # Assert the result is a DataFrame
-        self.assertIsInstance(result, pd.DataFrame)
-        self.assertEqual(result.shape, (2, 2))
-
     @patch("psycopg2.connect")
     @patch("pandas.read_sql")
     def test_load_from_cockroachdb(self, mock_read_sql, mock_postgresql_connect):
@@ -229,29 +199,6 @@ def test_load_from_postgres_with_params(self, mock_read_sql, mock_psycopg2_conne
         self.assertIsInstance(result, pd.DataFrame)
         self.assertEqual(result.shape, (2, 2))
 
-    @patch("sqlite3.connect")
-    @patch("pandas.read_sql")
-    def test_load_from_sqlite_with_params(self, mock_read_sql, mock_sqlite3_connect):
-        mock_conn = MagicMock()
-        mock_sqlite3_connect.return_value = mock_conn
-        mock_read_sql.return_value = pd.DataFrame(
-            {"column1": [9, 10], "column2": [11, 12]}
-        )
-
-        connection_info = {"file_path": "test_db.sqlite"}
-        query = "SELECT * FROM test_table WHERE age > ?"
-        query_params = [30]
-
-        connection_config = SqliteConnectionConfig(**connection_info)
-
-        result = load_from_sqlite(connection_config, query, query_params)
-
-        mock_sqlite3_connect.assert_called_once_with("test_db.sqlite")
-        mock_read_sql.assert_called_once_with(query, mock_conn, params=query_params)
-
-        self.assertIsInstance(result, pd.DataFrame)
-        self.assertEqual(result.shape, (2, 2))
-
     @patch("psycopg2.connect")
     @patch("pandas.read_sql")
     def test_load_from_cockroachdb_with_params(

pandasai/__init__.py

Lines changed: 27 additions & 39 deletions
@@ -4,7 +4,6 @@
 """
 
 import os
-import re
 from io import BytesIO
 from typing import List, Optional, Union
 from zipfile import ZipFile
@@ -27,6 +26,7 @@
 from .constants import LOCAL_SOURCE_TYPES, SQL_SOURCE_TYPES
 from .core.cache import Cache
 from .data_loader.loader import DatasetLoader
+from .data_loader.query_builder import QueryBuilder
 from .data_loader.semantic_layer_schema import (
     Column,
 )
@@ -39,11 +39,11 @@
 def create(
     path: str,
     df: Optional[DataFrame] = None,
-    name: Optional[str] = None,
     description: Optional[str] = None,
     columns: Optional[List[dict]] = None,
     source: Optional[dict] = None,
     relations: Optional[List[dict]] = None,
+    view: bool = False,
 ) -> Union[DataFrame, VirtualDataFrame]:
     """
     Creates a new dataset at the specified path with optional metadata, schema,
@@ -85,7 +85,6 @@ def create(
     >>> create(
     ...     path="my-org/my-dataset",
     ...     df=my_dataframe,
-    ...     name="My Dataset",
     ...     description="This is a sample dataset.",
     ...     columns=[
     ...         {"name": "id", "type": "integer", "description": "Primary key"},
@@ -103,54 +102,46 @@ def create(
         find_project_root(), "datasets", org_name, dataset_name
     )
 
+    schema_path = os.path.join(str(dataset_directory), "schema.yaml")
+    parquet_file_path = os.path.join(str(dataset_directory), "data.parquet")
     # Check if dataset already exists
-    if os.path.exists(dataset_directory):
-        schema_path = os.path.join(dataset_directory, "schema.yaml")
-        if os.path.exists(schema_path):
-            raise ValueError(f"Dataset already exists at path: {path}")
+    if os.path.exists(dataset_directory) and os.path.exists(schema_path):
+        raise ValueError(f"Dataset already exists at path: {path}")
 
     os.makedirs(dataset_directory, exist_ok=True)
 
-    # Save schema to yaml
-    schema_path = os.path.join(dataset_directory, "schema.yaml")
-
-    if df is None and source is None:
-        raise InvalidConfigError("Please provide either a DataFrame or a source")
+    if df is None and source is None and not view:
+        raise InvalidConfigError(
+            "Please provide either a DataFrame, a Source or a View"
        )
 
     if df is not None:
         schema = df.schema
-        df.to_parquet(os.path.join(dataset_directory, "data.parquet"), index=False)
-    elif source.get("type") == "sqlite" and source.get("table"):
-        schema = SemanticLayerSchema(name=source.get("table"), source=Source(**source))
-        df = _dataset_loader.load(schema=schema)
-        df.to_parquet(os.path.join(dataset_directory, "data.parquet"), index=False)
-    elif source.get("table"):
-        schema = SemanticLayerSchema(name=source.get("table"), source=Source(**source))
-        df = _dataset_loader.load(schema=schema)
-    elif source.get("view"):
-        name = name or dataset_name
+        schema.name = sanitize_sql_table_name(dataset_name)
+        df.to_parquet(parquet_file_path, index=False)
+    elif view:
         _relation = [Relation(**relation) for relation in relations or ()]
-        schema = SemanticLayerSchema(
-            name=name, source=Source(**source), relations=_relation
+        schema: SemanticLayerSchema = SemanticLayerSchema(
+            name=sanitize_sql_table_name(dataset_name), relations=_relation, view=True
         )
-        df = _dataset_loader.load(schema=schema)
+    elif source.get("table"):
+        schema: SemanticLayerSchema = SemanticLayerSchema(
+            name=sanitize_sql_table_name(dataset_name), source=Source(**source)
+        )
+    else:
+        raise InvalidConfigError("Unable to create schema with the provided params")
 
-    schema.name = sanitize_sql_table_name(name or schema.name)
     schema.description = description or schema.description
     if columns:
         schema.columns = [Column(**column) for column in columns]
-    elif df is not None:
-        schema.columns = [
-            Column(name=str(name), type=DataFrame.get_column_type(dtype))
-            for name, dtype in df.dtypes.items()
-        ]
 
     with open(schema_path, "w") as yml_file:
         yml_file.write(schema.to_yaml())
 
     print(f"Dataset saved successfully to path: {dataset_directory}")
 
-    return _dataset_loader.load(path)
+    loader = DatasetLoader.create_loader_from_schema(schema, path)
+    return loader.load()
 
 
 # Global variable to store the current agent
@@ -206,9 +197,6 @@ def follow_up(query: str):
     return _current_agent.follow_up(query)
 
 
-_dataset_loader = DatasetLoader()
-
-
 def load(dataset_path: str) -> DataFrame:
     """
     Load data based on the provided dataset path.
@@ -223,7 +211,6 @@ def load(dataset_path: str) -> DataFrame:
     if len(path_parts) != 2:
         raise ValueError("The path must be in the format 'organization/dataset'.")
 
-    global _dataset_loader
     dataset_full_path = os.path.join(find_project_root(), "datasets", dataset_path)
     if not os.path.exists(dataset_full_path):
         api_key = os.environ.get("PANDABI_API_KEY", None)
@@ -244,13 +231,14 @@ def load(dataset_path: str) -> DataFrame:
         with ZipFile(BytesIO(file_data.content)) as zip_file:
             zip_file.extractall(dataset_full_path)
 
-    return _dataset_loader.load(dataset_path)
+    loader = DatasetLoader.create_loader_from_path(dataset_path)
+    return loader.load()
 
 
 def read_csv(filepath: str) -> DataFrame:
     data = pd.read_csv(filepath)
-    name = f"table_{sanitize_sql_table_name(filepath)}"
-    return DataFrame(data, name=name)
+    table = f"table_{sanitize_sql_table_name(filepath)}"
+    return DataFrame(data, _table_name=table)
 
 
 __all__ = [
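A hedged sketch of the reworked create() call shown in this diff: the name= argument is gone (the dataset name is derived from the path and sanitized), and views are now requested with view=True plus relations instead of a source entry. The org/dataset paths, the source dict fields, and the relation keys below are illustrative assumptions, not values taken from this commit.

import pandasai as pai

# A SQL-backed table: only the table name is shown; the real connection
# fields required by the source dict are omitted here.
sales = pai.create(
    path="acme/sales",  # hypothetical org/dataset path
    source={"type": "postgres", "table": "sales"},  # assumed source dict shape
    description="Sales table exposed through the semantic layer",
)

# A view over existing datasets: no source, just view=True and relations.
sales_by_region = pai.create(
    path="acme/sales-by-region",
    view=True,
    relations=[{"from": "sales.region_id", "to": "regions.id"}],  # assumed relation keys
    description="View joining sales to regions",
)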

pandasai/agent/base.py

Lines changed: 1 addition & 1 deletion
@@ -123,7 +123,7 @@ def _execute_local_sql_query(self, query: str) -> pd.DataFrame:
         with duckdb.connect() as con:
             # Register all DataFrames in the state
             for df in self._state.dfs:
-                con.register(df.name, df)
+                con.register(df.schema.source.table, df)
 
             # Execute the query and fetch the result as a pandas DataFrame
             result = con.sql(query).df()
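A standalone sketch of the registration pattern touched above, using only the public duckdb API: each frame is registered under its source table name (now df.schema.source.table in pandasai rather than df.name), so the generated SQL can refer to real table names. The frame contents and table name are made up for illustration.

import duckdb
import pandas as pd

# Hypothetical table-name -> frame mapping standing in for self._state.dfs.
frames = {"customers": pd.DataFrame({"id": [1, 2], "country": ["IT", "FR"]})}

with duckdb.connect() as con:
    for table_name, frame in frames.items():
        con.register(table_name, frame)  # register under the source table name
    result = con.sql("SELECT country, COUNT(*) AS n FROM customers GROUP BY country").df()

print(result)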

pandasai/constants.py

Lines changed: 2 additions & 5 deletions
@@ -29,29 +29,26 @@
     "mysql": "pandasai_sql",
     "postgres": "pandasai_sql",
     "cockroachdb": "pandasai_sql",
-    "sqlite": "pandasai_sql",
     "yahoo_finance": "pandasai_yfinance",
     "bigquery": "pandasai_bigquery",
     "snowflake": "pandasai_snowflake",
     "databricks": "pandasai_databricks",
     "oracle": "pandasai_oracle",
 }
 
-LOCAL_SOURCE_TYPES = ["csv", "parquet", "sqlite"]
+LOCAL_SOURCE_TYPES = ["csv", "parquet"]
 REMOTE_SOURCE_TYPES = [
     "mysql",
     "postgres",
     "cockroachdb",
-    "sqlite",
     "data",
     "yahoo_finance",
     "bigquery",
     "snowflake",
     "databricks",
     "oracle",
 ]
-SQL_SOURCE_TYPES = ["mysql", "postgres", "cockroachdb", "oracle", "sqlite"]
-
+SQL_SOURCE_TYPES = ["mysql", "postgres", "cockroachdb", "oracle"]
 VALID_COLUMN_TYPES = ["string", "integer", "float", "datetime", "boolean"]
 
 VALID_TRANSFORMATION_TYPES = [

pandasai/core/code_generation/code_cleaning.py

Lines changed: 6 additions & 3 deletions
@@ -29,7 +29,7 @@ def _check_direct_sql_func_def_exists(self, node: ast.AST) -> bool:
         return isinstance(node, ast.FunctionDef) and node.name == "execute_sql_query"
 
     def _replace_table_names(
-        self, sql_query: str, table_names: list, allowed_table_names: list
+        self, sql_query: str, table_names: list, allowed_table_names: dict
     ) -> str:
         """
         Replace table names in the SQL query with case-sensitive or authorized table names.
@@ -54,8 +54,11 @@ def _clean_sql_query(self, sql_query: str) -> str:
         """
         sql_query = sql_query.rstrip(";")
         table_names = extract_table_names(sql_query)
-        allowed_table_names = {df.name: df.name for df in self.context.dfs} | {
-            f'"{df.name}"': df.name for df in self.context.dfs
+        allowed_table_names = {
+            df.schema.source.table: df.schema.source.table for df in self.context.dfs
+        } | {
+            f'"{df.schema.source.table}"': df.schema.source.table
+            for df in self.context.dfs
         }
         return self._replace_table_names(sql_query, table_names, allowed_table_names)
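A tiny illustration of the mapping built in _clean_sql_query above (the table names are made up): both the bare and the double-quoted spelling of each source table resolve to the canonical name, so either form in a generated query is rewritten consistently.

tables = ["loans", "payments"]  # stand-ins for df.schema.source.table values
allowed_table_names = {t: t for t in tables} | {f'"{t}"': t for t in tables}
print(allowed_table_names)
# {'loans': 'loans', 'payments': 'payments', '"loans"': 'loans', '"payments"': 'payments'}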

0 commit comments
