fixed pydantic schema #356

Merged 2 commits on Jun 7, 2024
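
At its core, this PR replaces PydanticOutputParser with JsonOutputParser(pydantic_object=...) in the answer-generation nodes, so schema-constrained answers come back as plain, JSON-serializable dicts instead of Pydantic model instances; it also adds an IndexifyNode integration stub and an example script. A minimal sketch of the behavioral difference between the two parsers, assuming only that langchain-core and pydantic are installed:

from typing import List

from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from pydantic import BaseModel, Field

class Image(BaseModel):
    url: str = Field(description="The url of the image")

class Images(BaseModel):
    images: List[Image]

payload = '{"images": [{"url": "https://example.com/a.png"}]}'

# JsonOutputParser keeps the schema-aware format instructions but parses to a dict
json_parser = JsonOutputParser(pydantic_object=Images)
print(type(json_parser.parse(payload)))      # <class 'dict'>

# PydanticOutputParser returns a model instance, which json.dumps() cannot serialize
pydantic_parser = PydanticOutputParser(pydantic_object=Images)
print(type(pydantic_parser.parse(payload)))  # an Images model instance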
72 changes: 72 additions & 0 deletions examples/integrations/indexify_node_example.py
@@ -0,0 +1,72 @@
"""
Basic example of scraping pipeline using SmartScraper with schema
"""

import json
import os
from typing import List

from dotenv import load_dotenv
load_dotenv()

from pydantic import BaseModel, Field
from scrapegraphai.graphs import SmartScraperGraph
from scrapegraphai.integrations import IndexifyNode


# ************************************************
# Define the output schema for the graph
# ************************************************

class Image(BaseModel):
url: str = Field(description="The url of the image")

class Images(BaseModel):
images: List[Image]

# ************************************************
# Define the configuration for the graph
# ************************************************

openai_key = os.getenv("OPENAI_APIKEY")

graph_config = {
"llm": {
"api_key":openai_key,
"model": "gpt-3.5-turbo",
},
"verbose": True,
"headless": False,
}

# ************************************************
# Define the custom nodes for the graph
# ************************************************

indexify_node = IndexifyNode(
input="answer & img_urls",
output=["is_indexed"],
node_config={
"verbose": True
}
)

# ************************************************
# Create the SmartScraperGraph instance
# ************************************************

smart_scraper_graph = SmartScraperGraph(
prompt="List me all the images with their url",
source="https://giphy.com/",
schema=Images,
config=graph_config
)

# Add the custom node to the graph
smart_scraper_graph.append_node(indexify_node)

# ************************************************
# Run the SmartScraperGraph
# ************************************************

result = smart_scraper_graph.run()
print(json.dumps(result, indent=2))
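
Since the nodes now parse with JsonOutputParser, result is already a plain dict and json.dumps works without any conversion. A hypothetical shape under the Images schema (the URLs below are illustrative, not real output):

result = {
    "images": [
        {"url": "https://media.giphy.com/media/abc123/giphy.gif"},  # hypothetical
        {"url": "https://media.giphy.com/media/def456/giphy.gif"},  # hypothetical
    ],
}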
8 changes: 0 additions & 8 deletions requirements-dev.lock
@@ -185,10 +185,6 @@ idna==3.7
# via yarl
imagesize==1.4.1
# via sphinx
importlib-metadata==7.1.0
# via sphinx
importlib-resources==6.4.0
# via matplotlib
iniconfig==2.0.0
# via pytest
jinja2==3.1.4
@@ -475,7 +471,6 @@ typing-extensions==4.12.0
# via pyee
# via sf-hamilton
# via sqlalchemy
# via starlette
# via streamlit
# via typer
# via typing-inspect
@@ -507,6 +502,3 @@ win32-setctime==1.1.0
# via loguru
yarl==1.9.4
# via aiohttp
zipp==3.19.1
# via importlib-metadata
# via importlib-resources
3 changes: 2 additions & 1 deletion scrapegraphai/integrations/__init__.py
@@ -2,4 +2,5 @@
Init file for integrations module
"""

from .burr_bridge import BurrBridge
from .burr_bridge import BurrBridge
from .indexify_node import IndexifyNode
79 changes: 79 additions & 0 deletions scrapegraphai/integrations/indexify_node.py
@@ -0,0 +1,79 @@
"""
IndexifyNode Module
"""

from typing import List, Optional

from ..utils.logging import get_logger
from ..nodes.base_node import BaseNode

# try:
# import indexify
# except ImportError:
# raise ImportError("indexify package is not installed. Please install it with 'pip install scrapegraphai[indexify]'")


class IndexifyNode(BaseNode):
"""
A node responsible for indexing the content present in the state.

Attributes:
verbose (bool): A flag indicating whether to show print statements during execution.

Args:
input (str): Boolean expression defining the input keys needed from the state.
output (List[str]): List of output keys to be updated in the state.
node_config (dict): Additional configuration for the node.
        node_name (str): The unique identifier name for the node, defaulting to "Indexify".
"""

def __init__(
self,
input: str,
output: List[str],
node_config: Optional[dict] = None,
node_name: str = "Indexify",
):
super().__init__(node_name, "node", input, output, 2, node_config)

self.verbose = (
False if node_config is None else node_config.get("verbose", False)
)

def execute(self, state: dict) -> dict:
"""
Executes the node's logic to index the content present in the state.

Args:
state (dict): The current state of the graph. The input keys will be used to fetch the
correct data from the state.

Returns:
            dict: The updated state with the output key reporting whether the content was indexed.

Raises:
            KeyError: If the input keys are not found in the state, indicating that the
                necessary information for indexing the content is missing.
"""

self.logger.info(f"--- Executing {self.node_name} Node ---")

# Interpret input keys based on the provided input expression
# input_keys length matches the min_input_len parameter in the __init__ method
# e.g. "answer & parsed_doc" or "answer | img_urls"

input_keys = self.get_input_keys(state)

# Fetching data from the state based on the input keys
input_data = [state[key] for key in input_keys]

answer = input_data[0]
img_urls = input_data[1]

        # Index the content (the actual indexify integration is stubbed out for
        # now; see the commented import at the top of this module)
        # ...

        is_indexed = True
        state.update({self.output[0]: is_indexed})

return state
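
For a quick sense of the node's state contract, here is a hedged sketch that drives execute directly with a hand-built state, assuming BaseNode.get_input_keys resolves the "answer & img_urls" expression as in the example script above (the indexing step itself is still a stub):

from scrapegraphai.integrations import IndexifyNode

node = IndexifyNode(input="answer & img_urls", output=["is_indexed"])

state = {
    "answer": {"images": [{"url": "https://example.com/a.png"}]},  # illustrative
    "img_urls": ["https://example.com/a.png"],                     # illustrative
}

state = node.execute(state)
print(state["is_indexed"])  # True -- the stub always reports success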
7 changes: 2 additions & 5 deletions scrapegraphai/nodes/generate_answer_csv_node.py
@@ -8,7 +8,7 @@

# Imports from Langchain
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
from tqdm import tqdm

@@ -96,7 +96,7 @@ def execute(self, state):

# Initialize the output parser
if self.node_config.get("schema", None) is not None:
output_parser = PydanticOutputParser(pydantic_object=self.node_config.get("schema", None))
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
else:
output_parser = JsonOutputParser()

@@ -150,9 +150,6 @@ def execute(self, state):
single_chain = list(chains_dict.values())[0]
answer = single_chain.invoke({"question": user_prompt})

if type(answer) == PydanticOutputParser:
answer = answer.model_dump()

# Update the state with the generated answer
state.update({self.output[0]: answer})
return state
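
Besides the parser swap, the deleted type(answer) == PydanticOutputParser guard (removed here and in the other answer nodes below) was dead code: answer holds the chain's parsed output, never the parser class, so the comparison could never be true. A small illustration, assuming langchain-core:

from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser

answer = JsonOutputParser().parse('{"ok": true}')
print(type(answer))                          # <class 'dict'>
print(type(answer) == PydanticOutputParser)  # False: a dict's type is never the parser class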
10 changes: 4 additions & 6 deletions scrapegraphai/nodes/generate_answer_node.py
@@ -7,10 +7,11 @@

# Imports from Langchain
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
from tqdm import tqdm


from ..utils.logging import get_logger
from ..models import Ollama
# Imports from the library
@@ -81,8 +82,8 @@ def execute(self, state: dict) -> dict:
doc = input_data[1]

# Initialize the output parser
if self.node_config.get("schema",None) is not None:
output_parser = PydanticOutputParser(pydantic_object=self.node_config.get("schema", None))
if self.node_config.get("schema", None) is not None:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
else:
output_parser = JsonOutputParser()

@@ -129,9 +130,6 @@ def execute(self, state: dict) -> dict:
single_chain = list(chains_dict.values())[0]
answer = single_chain.invoke({"question": user_prompt})

if type(answer) == PydanticOutputParser:
answer = answer.model_dump()

# Update the state with the generated answer
state.update({self.output[0]: answer})
return state
7 changes: 2 additions & 5 deletions scrapegraphai/nodes/generate_answer_omni_node.py
@@ -7,7 +7,7 @@

# Imports from Langchain
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
from tqdm import tqdm
from ..models import Ollama
@@ -82,7 +82,7 @@ def execute(self, state: dict) -> dict:

# Initialize the output parser
if self.node_config.get("schema", None) is not None:
output_parser = PydanticOutputParser(pydantic_object=self.node_config.get("schema", None))
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
else:
output_parser = JsonOutputParser()

@@ -141,9 +141,6 @@ def execute(self, state: dict) -> dict:
single_chain = list(chains_dict.values())[0]
answer = single_chain.invoke({"question": user_prompt})

if type(answer) == PydanticOutputParser:
answer = answer.model_dump()

# Update the state with the generated answer
state.update({self.output[0]: answer})
return state
6 changes: 3 additions & 3 deletions scrapegraphai/nodes/generate_answer_pdf_node.py
@@ -7,7 +7,7 @@

# Imports from Langchain
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
from tqdm import tqdm
from ..models import Ollama
@@ -96,8 +96,8 @@ def execute(self, state):
doc = input_data[1]

# Initialize the output parser
if self.node_config.get("schema",None) is not None:
output_parser = PydanticOutputParser(pydantic_object=self.node_config.get("schema", None))
if self.node_config.get("schema", None) is not None:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
else:
output_parser = JsonOutputParser()

11 changes: 3 additions & 8 deletions scrapegraphai/nodes/merge_answers_node.py
@@ -8,7 +8,7 @@

# Imports from Langchain
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.output_parsers import JsonOutputParser
from tqdm import tqdm

from ..utils.logging import get_logger
@@ -80,10 +80,8 @@ def execute(self, state: dict) -> dict:
answers_str += f"CONTENT WEBSITE {i+1}: {answer}\n"

# Initialize the output parser
if self.node_config["schema"] is not None:
output_parser = PydanticOutputParser(
pydantic_object=self.node_config["schema"]
)
if self.node_config.get("schema", None) is not None:
output_parser = JsonOutputParser(pydantic_object=self.node_config["schema"])
else:
output_parser = JsonOutputParser()

@@ -111,9 +109,6 @@ def execute(self, state: dict) -> dict:
merge_chain = prompt_template | self.llm_model | output_parser
answer = merge_chain.invoke({"user_prompt": user_prompt})

if type(answer) == PydanticOutputParser:
answer = answer.model_dump()

# Update the state with the generated answer
state.update({self.output[0]: answer})
return state