Skip to content

Commit 0c6738b

Browse files
authored
feature(FileManager): adding FileManager to make it feasible to work with the library in other environments (#1573)
1 parent d2350a1 commit 0c6738b

File tree

16 files changed

+176
-268
lines changed

16 files changed

+176
-268
lines changed

pandasai/__init__.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -98,17 +98,17 @@ def create(
9898

9999
org_name, dataset_name = get_validated_dataset_path(path)
100100

101-
dataset_directory = os.path.join(
102-
find_project_root(), "datasets", org_name, dataset_name
103-
)
101+
dataset_directory = str(os.path.join(org_name, dataset_name))
104102

105-
schema_path = os.path.join(str(dataset_directory), "schema.yaml")
106-
parquet_file_path = os.path.join(str(dataset_directory), "data.parquet")
103+
schema_path = os.path.join(dataset_directory, "schema.yaml")
104+
parquet_file_path = os.path.join(dataset_directory, "data.parquet")
105+
106+
file_manager = config.get().file_manager
107107
# Check if dataset already exists
108-
if os.path.exists(dataset_directory) and os.path.exists(schema_path):
108+
if file_manager.exists(dataset_directory) and file_manager.exists(schema_path):
109109
raise ValueError(f"Dataset already exists at path: {path}")
110110

111-
os.makedirs(dataset_directory, exist_ok=True)
111+
file_manager.mkdir(dataset_directory)
112112

113113
if df is None and source is None and not view:
114114
raise InvalidConfigError(
@@ -135,8 +135,7 @@ def create(
135135
if columns:
136136
schema.columns = [Column(**column) for column in columns]
137137

138-
with open(schema_path, "w") as yml_file:
139-
yml_file.write(schema.to_yaml())
138+
file_manager.write(schema_path, schema.to_yaml())
140139

141140
print(f"Dataset saved successfully to path: {dataset_directory}")
142141

pandasai/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import os
2+
from abc import ABC, abstractmethod
23
from importlib.util import find_spec
34
from typing import Any, Dict, Optional
45

56
from pydantic import BaseModel, ConfigDict
67

8+
from pandasai.helpers.filemanager import DefaultFileManager, FileManager
79
from pandasai.llm.base import LLM
810

911

@@ -13,6 +15,7 @@ class Config(BaseModel):
1315
enable_cache: bool = True
1416
max_retries: int = 3
1517
llm: Optional[LLM] = None
18+
file_manager: FileManager = DefaultFileManager()
1619

1720
model_config = ConfigDict(arbitrary_types_allowed=True)
1821

pandasai/core/prompts/file_based_prompt.py

Lines changed: 0 additions & 40 deletions
This file was deleted.

pandasai/data_loader/loader.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55

66
from pandasai.dataframe.base import DataFrame
77
from pandasai.exceptions import MethodNotImplementedError
8-
from pandasai.helpers.path import find_project_root
98
from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name
109

10+
from .. import ConfigManager
1111
from ..constants import (
1212
LOCAL_SOURCE_TYPES,
1313
)
@@ -48,21 +48,22 @@ def create_loader_from_path(cls, dataset_path: str) -> "DatasetLoader":
4848
"""
4949
Factory method to create the appropriate loader based on the dataset type.
5050
"""
51-
schema = cls._read_local_schema(dataset_path)
51+
schema = cls._read_schema_file(dataset_path)
5252
return DatasetLoader.create_loader_from_schema(schema, dataset_path)
5353

5454
@staticmethod
55-
def _read_local_schema(dataset_path: str) -> SemanticLayerSchema:
56-
schema_path = os.path.join(
57-
find_project_root(), "datasets", dataset_path, "schema.yaml"
58-
)
59-
if not os.path.exists(schema_path):
55+
def _read_schema_file(dataset_path: str) -> SemanticLayerSchema:
56+
schema_path = os.path.join(dataset_path, "schema.yaml")
57+
58+
file_manager = ConfigManager.get().file_manager
59+
60+
if not file_manager.exists(schema_path):
6061
raise FileNotFoundError(f"Schema file not found: {schema_path}")
6162

62-
with open(schema_path, "r") as file:
63-
raw_schema = yaml.safe_load(file)
64-
raw_schema["name"] = sanitize_sql_table_name(raw_schema["name"])
65-
return SemanticLayerSchema(**raw_schema)
63+
schema_file = file_manager.load(schema_path)
64+
raw_schema = yaml.safe_load(schema_file)
65+
raw_schema["name"] = sanitize_sql_table_name(raw_schema["name"])
66+
return SemanticLayerSchema(**raw_schema)
6667

6768
def load(self) -> DataFrame:
6869
"""
@@ -80,6 +81,3 @@ def _apply_transformations(self, df: pd.DataFrame) -> pd.DataFrame:
8081

8182
transformation_manager = TransformationManager(df)
8283
return transformation_manager.apply_transformations(self.schema.transformations)
83-
84-
def _get_abs_dataset_path(self):
85-
return os.path.join(find_project_root(), "datasets", self.dataset_path)

pandasai/data_loader/local_loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def _load_from_local_source(self) -> pd.DataFrame:
3737
)
3838

3939
filepath = os.path.join(
40-
str(self._get_abs_dataset_path()),
40+
self.dataset_path,
4141
self.schema.source.path,
4242
)
4343

pandasai/data_loader/sql_loader.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ def execute_query(self, query: str, params: Optional[list] = None) -> pd.DataFra
4242
raise MaliciousQueryError(
4343
"The SQL query is deemed unsafe and will not be executed."
4444
)
45-
4645
try:
4746
dataframe: pd.DataFrame = load_function(
4847
connection_info, formatted_query, params

pandasai/dataframe/base.py

Lines changed: 24 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pandas._typing import Axes, Dtype
1111

1212
import pandasai as pai
13-
from pandasai.config import Config
13+
from pandasai.config import Config, ConfigManager
1414
from pandasai.core.response import BaseResponse
1515
from pandasai.data_loader.semantic_layer_schema import (
1616
Column,
@@ -19,7 +19,6 @@
1919
)
2020
from pandasai.exceptions import DatasetNotFound, PandaAIApiKeyError
2121
from pandasai.helpers.dataframe_serializer import DataframeSerializer
22-
from pandasai.helpers.path import find_project_root
2322
from pandasai.helpers.session import get_pandaai_session
2423

2524
if TYPE_CHECKING:
@@ -164,38 +163,32 @@ def push(self):
164163
"name": self.schema.name,
165164
}
166165

167-
dataset_directory = os.path.join(find_project_root(), "datasets", self.path)
168-
166+
dataset_directory = os.path.join("datasets", self.path)
167+
file_manager = ConfigManager.get().file_manager
169168
headers = {"accept": "application/json", "x-authorization": f"Bearer {api_key}"}
170169

171170
files = []
172171
schema_file_path = os.path.join(dataset_directory, "schema.yaml")
173172
data_file_path = os.path.join(dataset_directory, "data.parquet")
174173

175-
try:
176-
# Open schema.yaml
177-
schema_file = open(schema_file_path, "rb")
178-
files.append(("files", ("schema.yaml", schema_file, "application/x-yaml")))
179-
180-
# Check if data.parquet exists and open it
181-
if os.path.exists(data_file_path):
182-
data_file = open(data_file_path, "rb")
183-
files.append(
184-
("files", ("data.parquet", data_file, "application/octet-stream"))
185-
)
186-
187-
# Send the POST request
188-
request_session.post(
189-
"/datasets/push",
190-
files=files,
191-
params=params,
192-
headers=headers,
174+
# Open schema.yaml
175+
schema_file = file_manager.load_binary(schema_file_path)
176+
files.append(("files", ("schema.yaml", schema_file, "application/x-yaml")))
177+
178+
# Check if data.parquet exists and open it
179+
if file_manager.exists(data_file_path):
180+
data_file = file_manager.load_binary(data_file_path)
181+
files.append(
182+
("files", ("data.parquet", data_file, "application/octet-stream"))
193183
)
194184

195-
finally:
196-
# Ensure files are closed after the request
197-
for _, (name, file, _) in files:
198-
file.close()
185+
# Send the POST request
186+
request_session.post(
187+
"/datasets/push",
188+
files=files,
189+
params=params,
190+
headers=headers,
191+
)
199192

200193
print("Your dataset was successfully pushed to the remote server!")
201194
print(f"🔗 URL: https://app.pandabi.ai/datasets/{self.path}")
@@ -218,20 +211,18 @@ def pull(self):
218211

219212
with ZipFile(BytesIO(file_data.content)) as zip_file:
220213
for file_name in zip_file.namelist():
221-
target_path = os.path.join(
222-
find_project_root(), "datasets", self.path, file_name
223-
)
214+
target_path = os.path.join(self.path, file_name)
224215

216+
file_manager = ConfigManager.get().file_manager
225217
# Check if the file already exists
226-
if os.path.exists(target_path):
218+
if file_manager.exists(target_path):
227219
print(f"Replacing existing file: {target_path}")
228220

229221
# Ensure target directory exists
230-
os.makedirs(os.path.dirname(target_path), exist_ok=True)
222+
file_manager.mkdir(os.path.dirname(target_path))
231223

232224
# Extract the file
233-
with open(target_path, "wb") as f:
234-
f.write(zip_file.read(file_name))
225+
file_manager.write_binary(target_path, zip_file.read(file_name))
235226

236227
# Reloads the Dataframe
237228
from pandasai import DatasetLoader

pandasai/helpers/filemanager.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import os
2+
from abc import ABC, abstractmethod
3+
4+
from pandasai.helpers.path import find_project_root
5+
6+
7+
class FileManager(ABC):
8+
"""Abstract base class for file loaders, supporting local and remote backends."""
9+
10+
@abstractmethod
11+
def load(self, file_path: str) -> str:
12+
"""Reads the content of a file."""
13+
pass
14+
15+
@abstractmethod
16+
def load_binary(self, file_path: str) -> bytes:
17+
"""Reads the content of a file as bytes."""
18+
pass
19+
20+
@abstractmethod
21+
def write(self, file_path: str, content: str) -> None:
22+
"""Writes content to a file."""
23+
pass
24+
25+
@abstractmethod
26+
def write_binary(self, file_path: str, content: bytes) -> None:
27+
"""Writes binary content to a file."""
28+
pass
29+
30+
@abstractmethod
31+
def exists(self, file_path: str) -> bool:
32+
"""Checks if a file or directory exists."""
33+
pass
34+
35+
@abstractmethod
36+
def mkdir(self, dir_path: str) -> None:
37+
"""Creates a directory if it doesn't exist."""
38+
pass
39+
40+
41+
class DefaultFileManager(FileManager):
42+
"""Local file system implementation of FileManager."""
43+
44+
def __init__(self):
45+
self.base_path = os.path.join(find_project_root(), "datasets")
46+
47+
def load(self, file_path: str) -> str:
48+
full_path = os.path.join(self.base_path, file_path)
49+
with open(full_path, "r", encoding="utf-8") as f:
50+
return f.read()
51+
52+
def load_binary(self, file_path: str) -> bytes:
53+
full_path = os.path.join(self.base_path, file_path)
54+
with open(full_path, "rb") as f:
55+
return f.read()
56+
57+
def write(self, file_path: str, content: str) -> None:
58+
full_path = os.path.join(self.base_path, file_path)
59+
with open(full_path, "w", encoding="utf-8") as f:
60+
f.write(content)
61+
62+
def write_binary(self, file_path: str, content: bytes) -> None:
63+
full_path = os.path.join(self.base_path, file_path)
64+
with open(full_path, "wb") as f:
65+
f.write(content)
66+
67+
def exists(self, file_path: str) -> bool:
68+
full_path = os.path.join(self.base_path, file_path)
69+
return os.path.exists(full_path)
70+
71+
def mkdir(self, dir_path: str) -> None:
72+
full_path = os.path.join(self.base_path, dir_path)
73+
os.makedirs(full_path, exist_ok=True)

pandasai/helpers/path.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ def find_project_root(filename=None):
1010

1111
# Get the path of the file that is be
1212
# ing executed
13+
1314
current_file_path = os.path.abspath(os.getcwd())
1415

1516
# Navigate back until we either find a $filename file or there is no parent

tests/unit_tests/agent/test_agent_chat.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@
77
import pytest
88

99
import pandasai as pai
10-
from pandasai import DataFrame, find_project_root
10+
from pandasai import DataFrame
1111
from pandasai.core.response import (
1212
ChartResponse,
1313
DataFrameResponse,
1414
NumberResponse,
1515
StringResponse,
1616
)
17+
from pandasai.helpers.filemanager import find_project_root
1718

1819
# Read the API key from an environment variable
1920
API_KEY = os.getenv("PANDABI_API_KEY_TEST_CHAT", None)

0 commit comments

Comments
 (0)