diff --git a/.python-version b/.python-version
index 8e34c813..1445aee8 100644
--- a/.python-version
+++ b/.python-version
@@ -1 +1 @@
-3.9.19
+3.10.14
diff --git a/examples/local_models/smart_scraper_ollama.py b/examples/local_models/smart_scraper_ollama.py
index babf4c2b..8c17ffa6 100644
--- a/examples/local_models/smart_scraper_ollama.py
+++ b/examples/local_models/smart_scraper_ollama.py
@@ -20,6 +20,7 @@
# "base_url": "http://localhost:11434", # set ollama URL arbitrarily
},
"verbose": True,
+ "headless": False
}
# ************************************************
diff --git a/examples/openai/smart_scraper_openai.py b/examples/openai/smart_scraper_openai.py
index 4f0952ae..ed10b409 100644
--- a/examples/openai/smart_scraper_openai.py
+++ b/examples/openai/smart_scraper_openai.py
@@ -19,9 +19,9 @@
graph_config = {
"llm": {
"api_key": openai_key,
- "model": "gpt-4o",
+ "model": "gpt-3.5-turbo",
},
- "verbose": True,
+ "verbose": False,
"headless": False,
}
diff --git a/examples/single_node/robot_node.py b/examples/single_node/robot_node.py
index 257c4efb..d824400a 100644
--- a/examples/single_node/robot_node.py
+++ b/examples/single_node/robot_node.py
@@ -11,7 +11,7 @@
graph_config = {
"llm": {
- "model": "ollama/llama3",
+ "model_name": "ollama/llama3",
"temperature": 0,
"streaming": True
},
diff --git a/requirements-dev.lock b/requirements-dev.lock
index 7c37321b..02ba2fde 100644
--- a/requirements-dev.lock
+++ b/requirements-dev.lock
@@ -93,7 +93,6 @@ graphviz==0.20.3
# via scrapegraphai
greenlet==3.0.3
# via playwright
- # via sqlalchemy
groq==0.5.0
# via langchain-groq
grpcio==1.63.0
diff --git a/requirements.lock b/requirements.lock
index c02d4522..09d427cc 100644
--- a/requirements.lock
+++ b/requirements.lock
@@ -92,7 +92,6 @@ graphviz==0.20.3
# via scrapegraphai
greenlet==3.0.3
# via playwright
- # via sqlalchemy
groq==0.5.0
# via langchain-groq
grpcio==1.63.0
diff --git a/scrapegraphai/asdt/__init__.py b/scrapegraphai/asdt/__init__.py
deleted file mode 100644
index 539534d6..00000000
--- a/scrapegraphai/asdt/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-"""
- __init__.py file for asdt module.
-"""
-
-from .dom_tree import DOMTree
diff --git a/scrapegraphai/asdt/dom_tree.py b/scrapegraphai/asdt/dom_tree.py
deleted file mode 100644
index 50b2e179..00000000
--- a/scrapegraphai/asdt/dom_tree.py
+++ /dev/null
@@ -1,52 +0,0 @@
-from bs4 import BeautifulSoup, Comment, NavigableString, Tag
-from .tree import Tree
-from .tree_node import TreeNode
-
-class DOMTree(Tree):
- def __init__(self, html_content):
- super().__init__()
- self.root = TreeNode('document')
- self.build_dom_tree(BeautifulSoup(html_content, 'html.parser'), self.root)
-
- def build_dom_tree(self, soup_node, tree_node):
- for child in soup_node.children:
- if isinstance(child, Comment):
- continue # Skip comments
- elif isinstance(child, NavigableString):
- text = child.strip()
- if text:
- new_node = TreeNode(value='text', attributes={'content': text})
- tree_node.add_child(new_node)
- new_node.finalize_node()
- elif isinstance(child, Tag):
- new_node = TreeNode(value=child.name, attributes=child.attrs)
- tree_node.add_child(new_node)
- self.build_dom_tree(child, new_node)
-
- def collect_text_nodes(self, exclude_script=True):
- texts = []
- metadatas = []
-
- def collect(node):
- # If node is a text node, collect its data
- if node.value == 'text':
- texts.append(node.attributes['content'])
- metadatas.append({
- 'root_path': node.root_path,
- 'closest_fork_path': node.closest_fork_path
- })
-
- # Traverse the DOM tree to collect text nodes and their metadata
- def traverse_for_text(node):
- # Skip traversal into script tags, but continue for other nodes
- if exclude_script and node.value == 'script':
- return # Skip script tags
-
- if node.leads_to_text or node.value == 'text':
- collect(node)
- for child in node.children:
- traverse_for_text(child)
-
- traverse_for_text(self.root)
- return texts, metadatas
-
diff --git a/scrapegraphai/asdt/tree.py b/scrapegraphai/asdt/tree.py
deleted file mode 100644
index be95f8e6..00000000
--- a/scrapegraphai/asdt/tree.py
+++ /dev/null
@@ -1,98 +0,0 @@
-from graphviz import Digraph
-
-class Tree:
- def __init__(self, root=None):
- self.root = root
-
- def traverse(self, visit_func):
- def _traverse(node):
- if node:
- visit_func(node)
- for child in node.children:
- _traverse(child)
- _traverse(self.root)
-
- def get_subtrees(self):
- # Retrieves all subtrees rooted at fork nodes
- return self.root.get_subtrees() if self.root else []
-
- def generate_subtree_dicts(self):
- subtree_dicts = []
-
- def aggregate_text_under_fork(fork_node):
- text_aggregate = {
- "content": [],
- "path_to_fork": ""
- }
- for child in fork_node.children:
- if child.value == 'text':
- text_aggregate["content"].append(child.attributes['content'])
- elif child.is_fork:
- continue
- else:
- for sub_child in child.children:
- text_aggregate["content"].append(sub_child.attributes)
-
- text_aggregate["path_to_fork"] = fork_node.closest_fork_path
- return text_aggregate
-
- def process_node(node):
- if node.is_fork:
- texts = aggregate_text_under_fork(node)
- if texts["content"]: # Only add if there's text content
- subtree_dicts.append({
- node.value: {
- "text": texts,
- "path_to_fork": texts["path_to_fork"],
- }
- })
- for child in node.children:
- process_node(child)
-
- process_node(self.root)
- return subtree_dicts
-
- def visualize(self, exclude_tags = ['script']):
- def add_nodes_edges(tree_node, graph):
- if tree_node:
- # Skip excluded tags
- if tree_node.value in exclude_tags:
- return
-
- # Format node label to include attributes
- attr_str = None
- label = f"{tree_node.value}\n[{attr_str}]" if attr_str else tree_node.value
- # Determine color based on node properties
- if tree_node.value == 'text':
- color = 'red' # Text nodes
- elif tree_node.is_fork:
- color = 'green' # Fork nodes
- elif tree_node.leads_to_text:
- color = 'lightblue2' # Nodes leading to text
- else:
- color = 'white' # Nodes that do not lead to text and are not forks
-
- # Customize node appearance
- graph.node(name=str(id(tree_node)), label=label,
- fontsize='12', shape='ellipse', color=color, fontcolor='black')
-
- if tree_node.parent:
- graph.edge(str(id(tree_node.parent)), str(id(tree_node)), fontsize='10')
-
- for child in tree_node.children:
- add_nodes_edges(child, graph)
-
-
- # Initialize Digraph, set graph and node attributes
- graph = Digraph()
- # graph.attr(size='10,10', dpi='300') # Set higher DPI for better image resolution
- graph.attr('node', style='filled', fontname='Helvetica')
- graph.attr('edge', fontname='Helvetica')
-
- add_nodes_edges(self.root, graph)
- graph.render('tree_visualization', view=True, format='svg') # Change format to SVG for vectorized output
-
- return graph
-
- def __repr__(self):
- return f"Tree(root={self.root})"
\ No newline at end of file
diff --git a/scrapegraphai/asdt/tree_node.py b/scrapegraphai/asdt/tree_node.py
deleted file mode 100644
index 636cb5c1..00000000
--- a/scrapegraphai/asdt/tree_node.py
+++ /dev/null
@@ -1,114 +0,0 @@
-from .tree import Tree
-
-class TreeNode:
- def __init__(self, value=None, attributes=None, children=None, parent=None, depth=0):
- self.value = value
- self.attributes = attributes if attributes is not None else {}
- self.children = children if children is not None else []
- self.parent = parent
- self.depth = depth
- # Flag to track if the subtree leads to text
- self.leads_to_text = False
- # Flags to track if the subtree has a direct leaf node
- self.has_direct_leaves = False
- self.root_path = self._compute_root_path()
- self.closest_fork_path = self._compute_fork_path()
- self.structure_hash = None
- self.content_hash = None
-
- def add_child(self, child_node):
- child_node.parent = self
- child_node.depth = self.depth + 1
- self.children.append(child_node)
- child_node.update_paths()
- self.update_leads_to_text()
- self.update_hashes() # Update hashes when the structure changes
-
- def update_hashes(self):
- self.structure_hash = self.hash_subtree_structure(self)
- self.content_hash = self.hash_subtree_content(self)
-
- def update_paths(self):
- self.root_path = self._compute_root_path()
- self.closest_fork_path = self._compute_fork_path()
-
- def update_leads_to_text(self):
- # Check if any child leads to text or is a text node
- if any(child.value == 'text' or child.leads_to_text for child in self.children):
- self.leads_to_text = True
- # Update the flag up the tree
- if self.parent and not self.parent.leads_to_text:
- self.parent.update_leads_to_text()
-
- def _compute_root_path(self):
- path = []
- current = self
- while current.parent:
- path.append(current.value)
- current = current.parent
- path.append('root') # Append 'root' to start of the path
- return '>'.join(reversed(path))
-
- def _compute_fork_path(self):
- path = []
- current = self
- while current.parent and len(current.parent.children) == 1:
- path.append(current.value)
- current = current.parent
- path.append(current.value) # Add the fork or root node
- return '>'.join(reversed(path))
-
- def finalize_node(self):
- if self.is_text and self.is_leaf:
- self.update_direct_leaves_flag()
-
- def update_direct_leaves_flag(self):
- ancestor = self.parent
- while ancestor and len(ancestor.children) == 1:
- ancestor = ancestor.parent
- if ancestor and ancestor.is_fork:
- ancestor.has_direct_leaves = True
-
- def get_subtrees(self, direct_leaves=False):
- # This method finds and returns subtrees rooted at this node and all descendant forks
- # Optionally filters to include only those with direct leaves beneath fork nodes
- subtrees = []
- if self.is_fork and (not direct_leaves or self.has_direct_leaves):
- subtrees.append(Tree(root=self))
- for child in self.children:
- subtrees.extend(child.get_subtrees(direct_leaves=direct_leaves))
- return subtrees
-
- def hash_subtree_structure(self, node):
- """ Recursively generate a hash for the subtree structure. """
- if node.is_leaf:
- return hash((node.value,)) # Simple hash for leaf nodes
- child_hashes = tuple(self.hash_subtree_structure(child) for child in node.children)
- return hash((node.value, child_hashes))
-
- def hash_subtree_content(self, node):
- """ Generate a hash based on the concatenated text of the subtree. """
- text_content = self.get_all_text(node).lower().strip()
- return hash(text_content)
-
- def get_all_text(self, node):
- """ Recursively get all text from a node and its descendants. """
- text = node.attributes.get('content', '') if node.value == 'text' else ''
- for child in node.children:
- text += self.get_all_text(child)
- return text
-
- def __repr__(self):
- return f"TreeNode(value={self.value}, leads_to_text={self.leads_to_text}, is_fork={self.is_fork})"
-
- @property
- def is_fork(self):
- return len(self.children) > 1
-
- @property
- def is_leaf(self):
- return len(self.children) == 0
-
- @property
- def is_text(self):
- return self.value == 'text'
\ No newline at end of file
diff --git a/scrapegraphai/docloaders/chromium.py b/scrapegraphai/docloaders/chromium.py
index 7d499245..64a74734 100644
--- a/scrapegraphai/docloaders/chromium.py
+++ b/scrapegraphai/docloaders/chromium.py
@@ -1,14 +1,13 @@
import asyncio
-import logging
from typing import Any, AsyncIterator, Iterator, List, Optional
from langchain_community.document_loaders.base import BaseLoader
from langchain_core.documents import Document
-from ..utils import Proxy, dynamic_import, parse_or_search_proxy
+from ..utils import Proxy, dynamic_import, get_logger, parse_or_search_proxy
-logger = logging.getLogger(__name__)
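+# route this module's logging through the library's centralized logger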
+logger = get_logger("web-loader")
class ChromiumLoader(BaseLoader):
diff --git a/scrapegraphai/graphs/__init__.py b/scrapegraphai/graphs/__init__.py
index 10eb6d8e..15f4a4ec 100644
--- a/scrapegraphai/graphs/__init__.py
+++ b/scrapegraphai/graphs/__init__.py
@@ -15,4 +15,3 @@
from .pdf_scraper_graph import PDFScraperGraph
from .omni_scraper_graph import OmniScraperGraph
from .omni_search_graph import OmniSearchGraph
-from .turbo_scraper import TurboScraperGraph
diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py
index 28eb27b2..839af910 100644
--- a/scrapegraphai/graphs/abstract_graph.py
+++ b/scrapegraphai/graphs/abstract_graph.py
@@ -1,15 +1,28 @@
"""
AbstractGraph Module
"""
+
from abc import ABC, abstractmethod
from typing import Optional
+
from langchain_aws import BedrockEmbeddings
-from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
from langchain_community.embeddings import HuggingFaceHubEmbeddings, OllamaEmbeddings
from langchain_google_genai import GoogleGenerativeAIEmbeddings
-from ..helpers import models_tokens
-from ..models import AzureOpenAI, Bedrock, Gemini, Groq, HuggingFace, Ollama, OpenAI, Anthropic
from langchain_google_genai.embeddings import GoogleGenerativeAIEmbeddings
+from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings
+
+from ..helpers import models_tokens
+from ..models import (
+ Anthropic,
+ AzureOpenAI,
+ Bedrock,
+ Gemini,
+ Groq,
+ HuggingFace,
+ Ollama,
+ OpenAI,
+)
+from ..utils.logging import set_verbosity_debug, set_verbosity_warning
class AbstractGraph(ABC):
@@ -46,9 +59,11 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
self.source = source
self.config = config
self.llm_model = self._create_llm(config["llm"], chat=True)
- self.embedder_model = self._create_default_embedder(llm_config=config["llm"]
- ) if "embeddings" not in config else self._create_embedder(
- config["embeddings"])
+ self.embedder_model = (
+ self._create_default_embedder(llm_config=config["llm"])
+ if "embeddings" not in config
+ else self._create_embedder(config["embeddings"])
+ )
# Create the graph
self.graph = self._create_graph()
@@ -56,17 +71,23 @@ def __init__(self, prompt: str, config: dict, source: Optional[str] = None):
self.execution_info = None
# Set common configuration parameters
- self.verbose = False if config is None else config.get(
- "verbose", False)
- self.headless = True if config is None else config.get(
- "headless", True)
+
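+        # "verbose" now sets the library-wide log level instead of gating per-node prints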
+ verbose = bool(config and config.get("verbose"))
+
+ if verbose:
+ set_verbosity_debug()
+ else:
+ set_verbosity_warning()
+
+ self.headless = True if config is None else config.get("headless", True)
self.loader_kwargs = config.get("loader_kwargs", {})
- common_params = {"headless": self.headless,
- "verbose": self.verbose,
- "loader_kwargs": self.loader_kwargs,
- "llm_model": self.llm_model,
- "embedder_model": self.embedder_model}
+ common_params = {
+ "headless": self.headless,
+ "loader_kwargs": self.loader_kwargs,
+ "llm_model": self.llm_model,
+ "embedder_model": self.embedder_model,
+ }
self.set_common_params(common_params, overwrite=False)
def set_common_params(self, params: dict, overwrite=False):
@@ -82,22 +103,22 @@ def set_common_params(self, params: dict, overwrite=False):
def _set_model_token(self, llm):
- if 'Azure' in str(type(llm)):
+ if "Azure" in str(type(llm)):
try:
self.model_token = models_tokens["azure"][llm.model_name]
except KeyError:
raise KeyError("Model not supported")
- elif 'HuggingFaceEndpoint' in str(type(llm)):
- if 'mistral' in llm.repo_id:
+ elif "HuggingFaceEndpoint" in str(type(llm)):
+ if "mistral" in llm.repo_id:
try:
- self.model_token = models_tokens['mistral'][llm.repo_id]
+ self.model_token = models_tokens["mistral"][llm.repo_id]
except KeyError:
raise KeyError("Model not supported")
- elif 'Google' in str(type(llm)):
+ elif "Google" in str(type(llm)):
try:
- if 'gemini' in llm.model:
- self.model_token = models_tokens['gemini'][llm.model]
+ if "gemini" in llm.model:
+ self.model_token = models_tokens["gemini"][llm.model]
except KeyError:
raise KeyError("Model not supported")
@@ -115,17 +136,14 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
KeyError: If the model is not supported.
"""
- llm_defaults = {
- "temperature": 0,
- "streaming": False
- }
+ llm_defaults = {"temperature": 0, "streaming": False}
llm_params = {**llm_defaults, **llm_config}
# If model instance is passed directly instead of the model details
- if 'model_instance' in llm_params:
+ if "model_instance" in llm_params:
if chat:
- self._set_model_token(llm_params['model_instance'])
- return llm_params['model_instance']
+ self._set_model_token(llm_params["model_instance"])
+ return llm_params["model_instance"]
# Instantiate the language model based on the model name
if "gpt-" in llm_params["model"]:
@@ -191,18 +209,20 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
elif "bedrock" in llm_params["model"]:
llm_params["model"] = llm_params["model"].split("/")[-1]
model_id = llm_params["model"]
- client = llm_params.get('client', None)
+ client = llm_params.get("client", None)
try:
self.model_token = models_tokens["bedrock"][llm_params["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
- return Bedrock({
- "client": client,
- "model_id": model_id,
- "model_kwargs": {
- "temperature": llm_params["temperature"],
+ return Bedrock(
+ {
+ "client": client,
+ "model_id": model_id,
+ "model_kwargs": {
+ "temperature": llm_params["temperature"],
+ },
}
- })
+ )
elif "claude-3-" in llm_params["model"]:
self.model_token = models_tokens["claude"]["claude3"]
return Anthropic(llm_params)
@@ -213,8 +233,7 @@ def _create_llm(self, llm_config: dict, chat=False) -> object:
raise KeyError("Model not supported") from exc
return DeepSeek(llm_params)
else:
- raise ValueError(
- "Model provided by the configuration not supported")
+ raise ValueError("Model provided by the configuration not supported")
def _create_default_embedder(self, llm_config=None) -> object:
"""
@@ -227,8 +246,9 @@ def _create_default_embedder(self, llm_config=None) -> object:
ValueError: If the model is not supported.
"""
if isinstance(self.llm_model, Gemini):
- return GoogleGenerativeAIEmbeddings(google_api_key=llm_config['api_key'],
- model="models/embedding-001")
+ return GoogleGenerativeAIEmbeddings(
+ google_api_key=llm_config["api_key"], model="models/embedding-001"
+ )
if isinstance(self.llm_model, OpenAI):
return OpenAIEmbeddings(api_key=self.llm_model.openai_api_key)
elif isinstance(self.llm_model, AzureOpenAIEmbeddings):
@@ -263,8 +283,8 @@ def _create_embedder(self, embedder_config: dict) -> object:
Raises:
KeyError: If the model is not supported.
"""
- if 'model_instance' in embedder_config:
- return embedder_config['model_instance']
+ if "model_instance" in embedder_config:
+ return embedder_config["model_instance"]
# Instantiate the embedding model based on the model name
if "openai" in embedder_config["model"]:
return OpenAIEmbeddings(api_key=embedder_config["api_key"])
@@ -281,28 +301,27 @@ def _create_embedder(self, embedder_config: dict) -> object:
try:
models_tokens["hugging_face"][embedder_config["model"]]
except KeyError as exc:
- raise KeyError("Model not supported")from exc
+ raise KeyError("Model not supported") from exc
return HuggingFaceHubEmbeddings(model=embedder_config["model"])
elif "gemini" in embedder_config["model"]:
try:
models_tokens["gemini"][embedder_config["model"]]
except KeyError as exc:
- raise KeyError("Model not supported")from exc
+ raise KeyError("Model not supported") from exc
return GoogleGenerativeAIEmbeddings(model=embedder_config["model"])
elif "bedrock" in embedder_config["model"]:
embedder_config["model"] = embedder_config["model"].split("/")[-1]
- client = embedder_config.get('client', None)
+ client = embedder_config.get("client", None)
try:
models_tokens["bedrock"][embedder_config["model"]]
except KeyError as exc:
raise KeyError("Model not supported") from exc
return BedrockEmbeddings(client=client, model_id=embedder_config["model"])
else:
- raise ValueError(
- "Model provided by the configuration not supported")
+ raise ValueError("Model provided by the configuration not supported")
def get_state(self, key=None) -> dict:
- """""
+ """ ""
Get the final state of the graph.
Args:
diff --git a/scrapegraphai/helpers/models_tokens.py b/scrapegraphai/helpers/models_tokens.py
index f8881d75..871291f5 100644
--- a/scrapegraphai/helpers/models_tokens.py
+++ b/scrapegraphai/helpers/models_tokens.py
@@ -5,6 +5,7 @@
models_tokens = {
"openai": {
"gpt-3.5-turbo-0125": 16385,
+ "gpt-3.5": 4096,
"gpt-3.5-turbo": 4096,
"gpt-3.5-turbo-1106": 16385,
"gpt-3.5-turbo-instruct": 4096,
diff --git a/scrapegraphai/nodes/base_node.py b/scrapegraphai/nodes/base_node.py
index cabfeda0..60f4c946 100644
--- a/scrapegraphai/nodes/base_node.py
+++ b/scrapegraphai/nodes/base_node.py
@@ -2,9 +2,11 @@
BaseNode Module
"""
-from abc import ABC, abstractmethod
-from typing import Optional, List
import re
+from abc import ABC, abstractmethod
+from typing import List, Optional
+
+from ..utils import get_logger
class BaseNode(ABC):
@@ -14,10 +16,11 @@ class BaseNode(ABC):
Attributes:
node_name (str): The unique identifier name for the node.
input (str): Boolean expression defining the input keys needed from the state.
- output (List[str]): List of
+        output (List[str]): List of output keys to be updated in the state.
min_input_len (int): Minimum required number of input keys.
node_config (Optional[dict]): Additional configuration for the node.
-
+        logger (logging.Logger): The centralized root logger.
+
Args:
node_name (str): Name for identifying the node.
node_type (str): Type of the node; must be 'node' or 'conditional_node'.
@@ -28,7 +31,7 @@ class BaseNode(ABC):
Raises:
ValueError: If `node_type` is not one of the allowed types.
-
+
Example:
>>> class MyNode(BaseNode):
... def execute(self, state):
@@ -40,18 +43,27 @@ class BaseNode(ABC):
{'key': 'value'}
"""
- def __init__(self, node_name: str, node_type: str, input: str, output: List[str],
- min_input_len: int = 1, node_config: Optional[dict] = None):
+ def __init__(
+ self,
+ node_name: str,
+ node_type: str,
+ input: str,
+ output: List[str],
+ min_input_len: int = 1,
+ node_config: Optional[dict] = None,
+ ):
self.node_name = node_name
self.input = input
self.output = output
self.min_input_len = min_input_len
self.node_config = node_config
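+        # every node logs through the shared library logger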
+ self.logger = get_logger()
if node_type not in ["node", "conditional_node"]:
raise ValueError(
- f"node_type must be 'node' or 'conditional_node', got '{node_type}'")
+ f"node_type must be 'node' or 'conditional_node', got '{node_type}'"
+ )
self.node_type = node_type
@abstractmethod
@@ -102,8 +114,7 @@ def get_input_keys(self, state: dict) -> List[str]:
self._validate_input_keys(input_keys)
return input_keys
except ValueError as e:
- raise ValueError(
- f"Error parsing input keys for {self.node_name}: {str(e)}")
+ raise ValueError(f"Error parsing input keys for {self.node_name}: {str(e)}")
def _validate_input_keys(self, input_keys):
"""
@@ -119,7 +130,8 @@ def _validate_input_keys(self, input_keys):
if len(input_keys) < self.min_input_len:
raise ValueError(
f"""{self.node_name} requires at least {self.min_input_len} input keys,
- got {len(input_keys)}.""")
+ got {len(input_keys)}."""
+ )
def _parse_input_keys(self, state: dict, expression: str) -> List[str]:
"""
@@ -142,67 +154,80 @@ def _parse_input_keys(self, state: dict, expression: str) -> List[str]:
raise ValueError("Empty expression.")
# Check for adjacent state keys without an operator between them
- pattern = r'\b(' + '|'.join(re.escape(key) for key in state.keys()) + \
- r')(\b\s*\b)(' + '|'.join(re.escape(key)
- for key in state.keys()) + r')\b'
+ pattern = (
+ r"\b("
+ + "|".join(re.escape(key) for key in state.keys())
+ + r")(\b\s*\b)("
+ + "|".join(re.escape(key) for key in state.keys())
+ + r")\b"
+ )
if re.search(pattern, expression):
raise ValueError(
- "Adjacent state keys found without an operator between them.")
+ "Adjacent state keys found without an operator between them."
+ )
# Remove spaces
expression = expression.replace(" ", "")
# Check for operators with empty adjacent tokens or at the start/end
- if expression[0] in '&|' or expression[-1] in '&|' \
- or '&&' in expression or '||' in expression or \
- '&|' in expression or '|&' in expression:
+ if (
+ expression[0] in "&|"
+ or expression[-1] in "&|"
+ or "&&" in expression
+ or "||" in expression
+ or "&|" in expression
+ or "|&" in expression
+ ):
raise ValueError("Invalid operator usage.")
# Check for balanced parentheses and valid operator placement
open_parentheses = close_parentheses = 0
for i, char in enumerate(expression):
- if char == '(':
+ if char == "(":
open_parentheses += 1
- elif char == ')':
+ elif char == ")":
close_parentheses += 1
# Check for invalid operator sequences
if char in "&|" and i + 1 < len(expression) and expression[i + 1] in "&|":
raise ValueError(
- "Invalid operator placement: operators cannot be adjacent.")
+ "Invalid operator placement: operators cannot be adjacent."
+ )
# Check for missing or balanced parentheses
if open_parentheses != close_parentheses:
- raise ValueError(
- "Missing or unbalanced parentheses in expression.")
+ raise ValueError("Missing or unbalanced parentheses in expression.")
# Helper function to evaluate an expression without parentheses
def evaluate_simple_expression(exp: str) -> List[str]:
"""Evaluate an expression without parentheses."""
# Split the expression by the OR operator and process each segment
- for or_segment in exp.split('|'):
+ for or_segment in exp.split("|"):
# Check if all elements in an AND segment are in state
- and_segment = or_segment.split('&')
+ and_segment = or_segment.split("&")
if all(elem.strip() in state for elem in and_segment):
- return [elem.strip() for elem in and_segment if elem.strip() in state]
+ return [
+ elem.strip() for elem in and_segment if elem.strip() in state
+ ]
return []
# Helper function to evaluate expressions with parentheses
def evaluate_expression(expression: str) -> List[str]:
"""Evaluate an expression with parentheses."""
-
- while '(' in expression:
- start = expression.rfind('(')
- end = expression.find(')', start)
- sub_exp = expression[start + 1:end]
+
+ while "(" in expression:
+ start = expression.rfind("(")
+ end = expression.find(")", start)
+ sub_exp = expression[start + 1 : end]
# Replace the evaluated part with a placeholder and then evaluate it
sub_result = evaluate_simple_expression(sub_exp)
# For simplicity in handling, join sub-results with OR to reprocess them later
- expression = expression[:start] + \
- '|'.join(sub_result) + expression[end+1:]
+ expression = (
+ expression[:start] + "|".join(sub_result) + expression[end + 1 :]
+ )
return evaluate_simple_expression(expression)
result = evaluate_expression(expression)
diff --git a/scrapegraphai/nodes/blocks_identifier.py b/scrapegraphai/nodes/blocks_identifier.py
index 70fd09a7..d06c9805 100644
--- a/scrapegraphai/nodes/blocks_identifier.py
+++ b/scrapegraphai/nodes/blocks_identifier.py
@@ -3,21 +3,22 @@
"""
from typing import List, Optional
+
from langchain_community.document_loaders import AsyncChromiumLoader
from langchain_core.documents import Document
-from .base_node import BaseNode
+from .base_node import BaseNode
class BlocksIndentifier(BaseNode):
"""
A node responsible to identify the blocks in the HTML content of a specified HTML content
- e.g products in a E-commerce, flights in a travel website etc.
+    e.g. products on an e-commerce site, flights on a travel website, etc.
Attributes:
headless (bool): A flag indicating whether the browser should run in headless mode.
verbose (bool): A flag indicating whether to print verbose output during execution.
-
+
Args:
input (str): Boolean expression defining the input keys needed from the state.
output (List[str]): List of output keys to be updated in the state.
@@ -25,11 +26,21 @@ class BlocksIndentifier(BaseNode):
node_name (str): The unique identifier name for the node, defaulting to "BlocksIndentifier".
"""
- def __init__(self, input: str, output: List[str], node_config: Optional[dict], node_name: str = "BlocksIndentifier"):
+ def __init__(
+ self,
+ input: str,
+ output: List[str],
+ node_config: Optional[dict],
+ node_name: str = "BlocksIndentifier",
+ ):
super().__init__(node_name, "node", input, output, 1)
- self.headless = True if node_config is None else node_config.get("headless", True)
- self.verbose = True if node_config is None else node_config.get("verbose", False)
+ self.headless = (
+ True if node_config is None else node_config.get("headless", True)
+ )
+ self.verbose = (
+ True if node_config is None else node_config.get("verbose", False)
+ )
def execute(self, state):
"""
@@ -47,8 +58,7 @@ def execute(self, state):
KeyError: If the input key is not found in the state, indicating that the
necessary information to perform the operation is missing.
"""
- if self.verbose:
- print(f"--- Executing {self.node_name} Node ---")
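+        # log via the shared logger; verbosity is configured globally in AbstractGraph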
+ self.logger.info(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
diff --git a/scrapegraphai/nodes/fetch_node.py b/scrapegraphai/nodes/fetch_node.py
index 0bfb0111..d3609e2e 100644
--- a/scrapegraphai/nodes/fetch_node.py
+++ b/scrapegraphai/nodes/fetch_node.py
@@ -1,18 +1,19 @@
-"""
+""""
FetchNode Module
"""
import json
-import requests
from typing import List, Optional
import pandas as pd
+import requests
from langchain_community.document_loaders import PyPDFLoader
from langchain_core.documents import Document
from ..docloaders import ChromiumLoader
-from .base_node import BaseNode
from ..utils.cleanup_html import cleanup_html
+from ..utils.logging import get_logger
+from .base_node import BaseNode
class FetchNode(BaseNode):
@@ -51,7 +52,7 @@ def __init__(
False if node_config is None else node_config.get("verbose", False)
)
self.useSoup = (
- False if node_config is None else node_config.get("useSoup", False)
+ False if node_config is None else node_config.get("useSoup", False)
)
self.loader_kwargs = (
{} if node_config is None else node_config.get("loader_kwargs", {})
@@ -73,8 +74,8 @@ def execute(self, state):
KeyError: If the input key is not found in the state, indicating that the
necessary information to perform the operation is missing.
"""
- if self.verbose:
- print(f"--- Executing {self.node_name} Node ---")
+
+ self.logger.info(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
@@ -92,7 +93,7 @@ def execute(self, state):
]
state.update({self.output[0]: compressed_document})
return state
-
+
# handling for pdf
elif input_keys[0] == "pdf":
loader = PyPDFLoader(source)
@@ -108,7 +109,7 @@ def execute(self, state):
]
state.update({self.output[0]: compressed_document})
return state
-
+
elif input_keys[0] == "json":
f = open(source)
compressed_document = [
@@ -116,7 +117,7 @@ def execute(self, state):
]
state.update({self.output[0]: compressed_document})
return state
-
+
elif input_keys[0] == "xml":
with open(source, "r", encoding="utf-8") as f:
data = f.read()
@@ -125,25 +126,29 @@ def execute(self, state):
]
state.update({self.output[0]: compressed_document})
return state
-
+
elif self.input == "pdf_dir":
pass
elif not source.startswith("http"):
title, minimized_body, link_urls, image_urls = cleanup_html(source, source)
parsed_content = f"Title: {title}, Body: {minimized_body}, Links: {link_urls}, Images: {image_urls}"
- compressed_document = [Document(page_content=parsed_content,
- metadata={"source": "local_dir"}
- )]
-
+ compressed_document = [
+ Document(page_content=parsed_content, metadata={"source": "local_dir"})
+ ]
+
elif self.useSoup:
response = requests.get(source)
if response.status_code == 200:
- title, minimized_body, link_urls, image_urls = cleanup_html(response.text, source)
+ title, minimized_body, link_urls, image_urls = cleanup_html(
+ response.text, source
+ )
parsed_content = f"Title: {title}, Body: {minimized_body}, Links: {link_urls}, Images: {image_urls}"
compressed_document = [Document(page_content=parsed_content)]
- else:
- print(f"Failed to retrieve contents from the webpage at url: {source}")
+ else:
+ self.logger.warning(
+ f"Failed to retrieve contents from the webpage at url: {source}"
+ )
else:
loader_kwargs = {}
@@ -153,14 +158,22 @@ def execute(self, state):
loader = ChromiumLoader([source], headless=self.headless, **loader_kwargs)
document = loader.load()
-
- title, minimized_body, link_urls, image_urls = cleanup_html(str(document[0].page_content), source)
+
+ title, minimized_body, link_urls, image_urls = cleanup_html(
+ str(document[0].page_content), source
+ )
parsed_content = f"Title: {title}, Body: {minimized_body}, Links: {link_urls}, Images: {image_urls}"
-
+
compressed_document = [
Document(page_content=parsed_content, metadata={"source": source})
]
- state.update({self.output[0]: compressed_document, self.output[1]: link_urls, self.output[2]: image_urls})
+ state.update(
+ {
+ self.output[0]: compressed_document,
+ self.output[1]: link_urls,
+ self.output[2]: image_urls,
+ }
+ )
- return state
\ No newline at end of file
+ return state
diff --git a/scrapegraphai/nodes/generate_answer_csv_node.py b/scrapegraphai/nodes/generate_answer_csv_node.py
index 53f7121b..7b5fbb14 100644
--- a/scrapegraphai/nodes/generate_answer_csv_node.py
+++ b/scrapegraphai/nodes/generate_answer_csv_node.py
@@ -1,14 +1,18 @@
"""
Module for generating the answer node
"""
+
# Imports from standard library
from typing import List, Optional
-from tqdm import tqdm
# Imports from Langchain
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
+from tqdm import tqdm
+
+from ..utils.logging import get_logger
# Imports from the library
from .base_node import BaseNode
@@ -23,15 +27,15 @@ class GenerateAnswerCSVNode(BaseNode):
Attributes:
llm_model: An instance of a language model client, configured for generating answers.
- node_name (str): The unique identifier name for the node, defaulting
+ node_name (str): The unique identifier name for the node, defaulting
to "GenerateAnswerNodeCsv".
- node_type (str): The type of the node, set to "node" indicating a
+ node_type (str): The type of the node, set to "node" indicating a
standard operational node.
Args:
- llm_model: An instance of the language model client (e.g., ChatOpenAI) used
+ llm_model: An instance of the language model client (e.g., ChatOpenAI) used
for generating answers.
- node_name (str, optional): The unique identifier name for the node.
+ node_name (str, optional): The unique identifier name for the node.
Defaults to "GenerateAnswerNodeCsv".
Methods:
@@ -39,8 +43,13 @@ class GenerateAnswerCSVNode(BaseNode):
updating the state with the generated answer under the 'answer' key.
"""
- def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None,
- node_name: str = "GenerateAnswer"):
+ def __init__(
+ self,
+ input: str,
+ output: List[str],
+ node_config: Optional[dict] = None,
+ node_name: str = "GenerateAnswer",
+ ):
"""
Initializes the GenerateAnswerNodeCsv with a language model client and a node name.
Args:
@@ -49,8 +58,9 @@ def __init__(self, input: str, output: List[str], node_config: Optional[dict] =
"""
super().__init__(node_name, "node", input, output, 2, node_config)
self.llm_model = node_config["llm_model"]
- self.verbose = False if node_config is None else node_config.get(
- "verbose", False)
+ self.verbose = (
+ False if node_config is None else node_config.get("verbose", False)
+ )
def execute(self, state):
"""
@@ -71,8 +81,7 @@ def execute(self, state):
that the necessary information for generating an answer is missing.
"""
- if self.verbose:
- print(f"--- Executing {self.node_name} Node ---")
+ self.logger.info(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
@@ -120,21 +129,27 @@ def execute(self, state):
chains_dict = {}
# Use tqdm to add progress bar
- for i, chunk in enumerate(tqdm(doc, desc="Processing chunks", disable=not self.verbose)):
+ for i, chunk in enumerate(
+ tqdm(doc, desc="Processing chunks", disable=not self.verbose)
+ ):
if len(doc) == 1:
prompt = PromptTemplate(
template=template_no_chunks,
input_variables=["question"],
- partial_variables={"context": chunk.page_content,
- "format_instructions": format_instructions},
+ partial_variables={
+ "context": chunk.page_content,
+ "format_instructions": format_instructions,
+ },
)
else:
prompt = PromptTemplate(
template=template_chunks,
input_variables=["question"],
- partial_variables={"context": chunk.page_content,
- "chunk_id": i + 1,
- "format_instructions": format_instructions},
+ partial_variables={
+ "context": chunk.page_content,
+ "chunk_id": i + 1,
+ "format_instructions": format_instructions,
+ },
)
# Dynamically name the chains based on their index
@@ -153,8 +168,7 @@ def execute(self, state):
partial_variables={"format_instructions": format_instructions},
)
merge_chain = merge_prompt | self.llm_model | output_parser
- answer = merge_chain.invoke(
- {"context": answer, "question": user_prompt})
+ answer = merge_chain.invoke({"context": answer, "question": user_prompt})
else:
# Chain
single_chain = list(chains_dict.values())[0]
diff --git a/scrapegraphai/nodes/generate_answer_node.py b/scrapegraphai/nodes/generate_answer_node.py
index 168ec4f3..b853951e 100644
--- a/scrapegraphai/nodes/generate_answer_node.py
+++ b/scrapegraphai/nodes/generate_answer_node.py
@@ -4,12 +4,14 @@
# Imports from standard library
from typing import List, Optional
-from tqdm import tqdm
# Imports from Langchain
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
+from tqdm import tqdm
+
+from ..utils.logging import get_logger
# Imports from the library
from .base_node import BaseNode
@@ -33,13 +35,19 @@ class GenerateAnswerNode(BaseNode):
node_name (str): The unique identifier name for the node, defaulting to "GenerateAnswer".
"""
- def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None,
- node_name: str = "GenerateAnswer"):
+ def __init__(
+ self,
+ input: str,
+ output: List[str],
+ node_config: Optional[dict] = None,
+ node_name: str = "GenerateAnswer",
+ ):
super().__init__(node_name, "node", input, output, 2, node_config)
self.llm_model = node_config["llm_model"]
- self.verbose = True if node_config is None else node_config.get(
- "verbose", False)
+ self.verbose = (
+ True if node_config is None else node_config.get("verbose", False)
+ )
def execute(self, state: dict) -> dict:
"""
@@ -58,8 +66,7 @@ def execute(self, state: dict) -> dict:
that the necessary information for generating an answer is missing.
"""
- if self.verbose:
- print(f"--- Executing {self.node_name} Node ---")
+ self.logger.info(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
@@ -107,21 +114,27 @@ def execute(self, state: dict) -> dict:
chains_dict = {}
# Use tqdm to add progress bar
- for i, chunk in enumerate(tqdm(doc, desc="Processing chunks", disable=not self.verbose)):
+ for i, chunk in enumerate(
+ tqdm(doc, desc="Processing chunks", disable=not self.verbose)
+ ):
if len(doc) == 1:
prompt = PromptTemplate(
template=template_no_chunks,
input_variables=["question"],
- partial_variables={"context": chunk.page_content,
- "format_instructions": format_instructions},
+ partial_variables={
+ "context": chunk.page_content,
+ "format_instructions": format_instructions,
+ },
)
else:
prompt = PromptTemplate(
template=template_chunks,
input_variables=["question"],
- partial_variables={"context": chunk.page_content,
- "chunk_id": i + 1,
- "format_instructions": format_instructions},
+ partial_variables={
+ "context": chunk.page_content,
+ "chunk_id": i + 1,
+ "format_instructions": format_instructions,
+ },
)
# Dynamically name the chains based on their index
@@ -140,8 +153,7 @@ def execute(self, state: dict) -> dict:
partial_variables={"format_instructions": format_instructions},
)
merge_chain = merge_prompt | self.llm_model | output_parser
- answer = merge_chain.invoke(
- {"context": answer, "question": user_prompt})
+ answer = merge_chain.invoke({"context": answer, "question": user_prompt})
else:
# Chain
single_chain = list(chains_dict.values())[0]
diff --git a/scrapegraphai/nodes/generate_answer_omni_node.py b/scrapegraphai/nodes/generate_answer_omni_node.py
index fc2e8786..1cdd2042 100644
--- a/scrapegraphai/nodes/generate_answer_omni_node.py
+++ b/scrapegraphai/nodes/generate_answer_omni_node.py
@@ -4,12 +4,12 @@
# Imports from standard library
from typing import List, Optional
-from tqdm import tqdm
# Imports from Langchain
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
+from tqdm import tqdm
# Imports from the library
from .base_node import BaseNode
@@ -33,13 +33,19 @@ class GenerateAnswerOmniNode(BaseNode):
node_name (str): The unique identifier name for the node, defaulting to "GenerateAnswer".
"""
- def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None,
- node_name: str = "GenerateAnswerOmni"):
+ def __init__(
+ self,
+ input: str,
+ output: List[str],
+ node_config: Optional[dict] = None,
+ node_name: str = "GenerateAnswerOmni",
+ ):
super().__init__(node_name, "node", input, output, 3, node_config)
self.llm_model = node_config["llm_model"]
- self.verbose = False if node_config is None else node_config.get(
- "verbose", False)
+ self.verbose = (
+ False if node_config is None else node_config.get("verbose", False)
+ )
def execute(self, state: dict) -> dict:
"""
@@ -58,8 +64,7 @@ def execute(self, state: dict) -> dict:
that the necessary information for generating an answer is missing.
"""
- if self.verbose:
- print(f"--- Executing {self.node_name} Node ---")
+ self.logger.info(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
@@ -112,22 +117,28 @@ def execute(self, state: dict) -> dict:
chains_dict = {}
# Use tqdm to add progress bar
- for i, chunk in enumerate(tqdm(doc, desc="Processing chunks", disable=not self.verbose)):
+ for i, chunk in enumerate(
+ tqdm(doc, desc="Processing chunks", disable=not self.verbose)
+ ):
if len(doc) == 1:
prompt = PromptTemplate(
template=template_no_chunks,
input_variables=["question"],
- partial_variables={"context": chunk.page_content,
- "format_instructions": format_instructions,
- "img_desc": imag_desc},
+ partial_variables={
+ "context": chunk.page_content,
+ "format_instructions": format_instructions,
+ "img_desc": imag_desc,
+ },
)
else:
prompt = PromptTemplate(
template=template_chunks,
input_variables=["question"],
- partial_variables={"context": chunk.page_content,
- "chunk_id": i + 1,
- "format_instructions": format_instructions},
+ partial_variables={
+ "context": chunk.page_content,
+ "chunk_id": i + 1,
+ "format_instructions": format_instructions,
+ },
)
# Dynamically name the chains based on their index
@@ -149,8 +160,7 @@ def execute(self, state: dict) -> dict:
},
)
merge_chain = merge_prompt | self.llm_model | output_parser
- answer = merge_chain.invoke(
- {"context": answer, "question": user_prompt})
+ answer = merge_chain.invoke({"context": answer, "question": user_prompt})
else:
# Chain
single_chain = list(chains_dict.values())[0]
diff --git a/scrapegraphai/nodes/generate_answer_pdf_node.py b/scrapegraphai/nodes/generate_answer_pdf_node.py
index 31839d22..ec5ef080 100644
--- a/scrapegraphai/nodes/generate_answer_pdf_node.py
+++ b/scrapegraphai/nodes/generate_answer_pdf_node.py
@@ -1,14 +1,17 @@
"""
Module for generating the answer node
"""
+
# Imports from standard library
from typing import List, Optional
-from tqdm import tqdm
# Imports from Langchain
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
+from tqdm import tqdm
+
+from ..utils.logging import get_logger
# Imports from the library
from .base_node import BaseNode
@@ -23,15 +26,15 @@ class GenerateAnswerPDFNode(BaseNode):
Attributes:
llm: An instance of a language model client, configured for generating answers.
- node_name (str): The unique identifier name for the node, defaulting
+ node_name (str): The unique identifier name for the node, defaulting
to "GenerateAnswerNodePDF".
- node_type (str): The type of the node, set to "node" indicating a
+ node_type (str): The type of the node, set to "node" indicating a
standard operational node.
Args:
- llm: An instance of the language model client (e.g., ChatOpenAI) used
+ llm: An instance of the language model client (e.g., ChatOpenAI) used
for generating answers.
- node_name (str, optional): The unique identifier name for the node.
+ node_name (str, optional): The unique identifier name for the node.
Defaults to "GenerateAnswerNodePDF".
Methods:
@@ -39,8 +42,13 @@ class GenerateAnswerPDFNode(BaseNode):
updating the state with the generated answer under the 'answer' key.
"""
- def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None,
- node_name: str = "GenerateAnswer"):
+ def __init__(
+ self,
+ input: str,
+ output: List[str],
+ node_config: Optional[dict] = None,
+ node_name: str = "GenerateAnswer",
+ ):
"""
Initializes the GenerateAnswerNodePDF with a language model client and a node name.
Args:
@@ -49,8 +57,9 @@ def __init__(self, input: str, output: List[str], node_config: Optional[dict] =
"""
super().__init__(node_name, "node", input, output, 2, node_config)
self.llm_model = node_config["llm"]
- self.verbose = False if node_config is None else node_config.get(
- "verbose", False)
+ self.verbose = (
+ False if node_config is None else node_config.get("verbose", False)
+ )
def execute(self, state):
"""
@@ -71,8 +80,7 @@ def execute(self, state):
that the necessary information for generating an answer is missing.
"""
- if self.verbose:
- print(f"--- Executing {self.node_name} Node ---")
+ self.logger.info(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
@@ -120,21 +128,27 @@ def execute(self, state):
chains_dict = {}
# Use tqdm to add progress bar
- for i, chunk in enumerate(tqdm(doc, desc="Processing chunks", disable=not self.verbose)):
+ for i, chunk in enumerate(
+ tqdm(doc, desc="Processing chunks", disable=not self.verbose)
+ ):
if len(doc) == 1:
prompt = PromptTemplate(
template=template_no_chunks,
input_variables=["question"],
- partial_variables={"context": chunk.page_content,
- "format_instructions": format_instructions},
+ partial_variables={
+ "context": chunk.page_content,
+ "format_instructions": format_instructions,
+ },
)
else:
prompt = PromptTemplate(
template=template_chunks,
input_variables=["question"],
- partial_variables={"context": chunk.page_content,
- "chunk_id": i + 1,
- "format_instructions": format_instructions},
+ partial_variables={
+ "context": chunk.page_content,
+ "chunk_id": i + 1,
+ "format_instructions": format_instructions,
+ },
)
# Dynamically name the chains based on their index
@@ -153,8 +167,7 @@ def execute(self, state):
partial_variables={"format_instructions": format_instructions},
)
merge_chain = merge_prompt | self.llm_model | output_parser
- answer = merge_chain.invoke(
- {"context": answer, "question": user_prompt})
+ answer = merge_chain.invoke({"context": answer, "question": user_prompt})
else:
# Chain
single_chain = list(chains_dict.values())[0]
diff --git a/scrapegraphai/nodes/generate_scraper_node.py b/scrapegraphai/nodes/generate_scraper_node.py
index 804635de..0c64b64a 100644
--- a/scrapegraphai/nodes/generate_scraper_node.py
+++ b/scrapegraphai/nodes/generate_scraper_node.py
@@ -4,12 +4,14 @@
# Imports from standard library
from typing import List, Optional
-from tqdm import tqdm
# Imports from Langchain
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel
+from tqdm import tqdm
+
+from ..utils.logging import get_logger
# Imports from the library
from .base_node import BaseNode
@@ -36,15 +38,24 @@ class GenerateScraperNode(BaseNode):
"""
- def __init__(self, input: str, output: List[str], library: str, website: str,
- node_config: Optional[dict]=None, node_name: str = "GenerateScraper"):
+ def __init__(
+ self,
+ input: str,
+ output: List[str],
+ library: str,
+ website: str,
+ node_config: Optional[dict] = None,
+ node_name: str = "GenerateScraper",
+ ):
super().__init__(node_name, "node", input, output, 2, node_config)
self.llm_model = node_config["llm_model"]
self.library = library
self.source = website
-
- self.verbose = False if node_config is None else node_config.get("verbose", False)
+
+ self.verbose = (
+ False if node_config is None else node_config.get("verbose", False)
+ )
def execute(self, state: dict) -> dict:
"""
@@ -62,8 +73,7 @@ def execute(self, state: dict) -> dict:
that the necessary information for generating an answer is missing.
"""
- if self.verbose:
- print(f"--- Executing {self.node_name} Node ---")
+ self.logger.info(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
@@ -92,17 +102,20 @@ def execute(self, state: dict) -> dict:
"""
print("source:", self.source)
if len(doc) > 1:
- raise NotImplementedError("Currently GenerateScraperNode cannot handle more than 1 context chunks")
+ raise NotImplementedError(
+ "Currently GenerateScraperNode cannot handle more than 1 context chunks"
+ )
else:
template = template_no_chunks
prompt = PromptTemplate(
template=template,
input_variables=["question"],
- partial_variables={"context": doc[0],
- "library": self.library,
- "source": self.source
- },
+ partial_variables={
+ "context": doc[0],
+ "library": self.library,
+ "source": self.source,
+ },
)
map_chain = prompt | self.llm_model | output_parser
diff --git a/scrapegraphai/nodes/get_probable_tags_node.py b/scrapegraphai/nodes/get_probable_tags_node.py
index e970c285..a26ded38 100644
--- a/scrapegraphai/nodes/get_probable_tags_node.py
+++ b/scrapegraphai/nodes/get_probable_tags_node.py
@@ -3,16 +3,19 @@
"""
from typing import List, Optional
+
from langchain.output_parsers import CommaSeparatedListOutputParser
from langchain.prompts import PromptTemplate
+
+from ..utils.logging import get_logger
from .base_node import BaseNode
class GetProbableTagsNode(BaseNode):
"""
- A node that utilizes a language model to identify probable HTML tags within a document that
+ A node that utilizes a language model to identify probable HTML tags within a document that
are likely to contain the information relevant to a user's query. This node generates a prompt
- describing the task, submits it to the language model, and processes the output to produce a
+ describing the task, submits it to the language model, and processes the output to produce a
list of probable tags.
Attributes:
@@ -25,16 +28,24 @@ class GetProbableTagsNode(BaseNode):
node_name (str): The unique identifier name for the node, defaulting to "GetProbableTags".
"""
- def __init__(self, input: str, output: List[str], model_config: dict,
- node_name: str = "GetProbableTags"):
- super().__init__(node_name, "node", input, output, 2, model_config)
-
- self.llm_model = model_config["llm_model"]
+ def __init__(
+ self,
+ input: str,
+ output: List[str],
+ node_config: dict,
+ node_name: str = "GetProbableTags",
+ ):
+ super().__init__(node_name, "node", input, output, 2, node_config)
+
+ self.llm_model = node_config["llm_model"]
+ self.verbose = (
+ False if node_config is None else node_config.get("verbose", False)
+ )
def execute(self, state: dict) -> dict:
"""
- Generates a list of probable HTML tags based on the user's input and updates the state
- with this list. The method constructs a prompt for the language model, submits it, and
+ Generates a list of probable HTML tags based on the user's input and updates the state
+ with this list. The method constructs a prompt for the language model, submits it, and
parses the output to identify probable tags.
Args:
@@ -49,7 +60,7 @@ def execute(self, state: dict) -> dict:
necessary information for generating tag predictions is missing.
"""
- print(f"--- Executing {self.node_name} Node ---")
+ self.logger.info(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
@@ -76,7 +87,9 @@ def execute(self, state: dict) -> dict:
template=template,
input_variables=["question"],
partial_variables={
- "format_instructions": format_instructions, "webpage": url},
+ "format_instructions": format_instructions,
+ "webpage": url,
+ },
)
# Execute the chain to get probable tags
diff --git a/scrapegraphai/nodes/graph_iterator_node.py b/scrapegraphai/nodes/graph_iterator_node.py
index 8a71319a..cd932986 100644
--- a/scrapegraphai/nodes/graph_iterator_node.py
+++ b/scrapegraphai/nodes/graph_iterator_node.py
@@ -8,6 +8,7 @@
from tqdm.asyncio import tqdm
+from ..utils.logging import get_logger
from .base_node import BaseNode
@@ -59,8 +60,9 @@ def execute(self, state: dict) -> dict:
"""
batchsize = self.node_config.get("batchsize", _default_batchsize)
- if self.verbose:
- print(f"--- Executing {self.node_name} Node with batchsize {batchsize} ---")
+ self.logger.info(
+ f"--- Executing {self.node_name} Node with batchsize {batchsize} ---"
+ )
try:
eventloop = asyncio.get_event_loop()
diff --git a/scrapegraphai/nodes/image_to_text_node.py b/scrapegraphai/nodes/image_to_text_node.py
index 49e99f72..7e7507a9 100644
--- a/scrapegraphai/nodes/image_to_text_node.py
+++ b/scrapegraphai/nodes/image_to_text_node.py
@@ -3,6 +3,8 @@
"""
from typing import List, Optional
+
+from ..utils.logging import get_logger
from .base_node import BaseNode
@@ -22,16 +24,18 @@ class ImageToTextNode(BaseNode):
"""
def __init__(
- self,
- input: str,
- output: List[str],
- node_config: Optional[dict]=None,
- node_name: str = "ImageToText",
- ):
+ self,
+ input: str,
+ output: List[str],
+ node_config: Optional[dict] = None,
+ node_name: str = "ImageToText",
+ ):
super().__init__(node_name, "node", input, output, 1, node_config)
self.llm_model = node_config["llm_model"]
- self.verbose = False if node_config is None else node_config.get("verbose", False)
+ self.verbose = (
+ False if node_config is None else node_config.get("verbose", False)
+ )
self.max_images = 5 if node_config is None else node_config.get("max_images", 5)
def execute(self, state: dict) -> dict:
@@ -47,9 +51,8 @@ def execute(self, state: dict) -> dict:
dict: The updated state with the input key containing the text extracted from the image.
"""
- if self.verbose:
- print(f"--- Executing {self.node_name} Node ---")
-
+ self.logger.info(f"--- Executing {self.node_name} Node ---")
+
input_keys = self.get_input_keys(state)
input_data = [state[key] for key in input_keys]
urls = input_data[0]
@@ -62,9 +65,9 @@ def execute(self, state: dict) -> dict:
# Skip the image-to-text conversion
if self.max_images < 1:
return state
-
+
img_desc = []
- for url in urls[:self.max_images]:
+ for url in urls[: self.max_images]:
try:
text_answer = self.llm_model.run(url)
except Exception as e:
diff --git a/scrapegraphai/nodes/merge_answers_node.py b/scrapegraphai/nodes/merge_answers_node.py
index e873309f..f64c3a9c 100644
--- a/scrapegraphai/nodes/merge_answers_node.py
+++ b/scrapegraphai/nodes/merge_answers_node.py
@@ -8,6 +8,9 @@
# Imports from Langchain
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
+from tqdm import tqdm
+
+from ..utils.logging import get_logger
# Imports from the library
from .base_node import BaseNode
@@ -28,17 +31,23 @@ class MergeAnswersNode(BaseNode):
node_name (str): The unique identifier name for the node, defaulting to "GenerateAnswer".
"""
- def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None,
- node_name: str = "MergeAnswers"):
+ def __init__(
+ self,
+ input: str,
+ output: List[str],
+ node_config: Optional[dict] = None,
+ node_name: str = "MergeAnswers",
+ ):
super().__init__(node_name, "node", input, output, 2, node_config)
self.llm_model = node_config["llm_model"]
- self.verbose = False if node_config is None else node_config.get(
- "verbose", False)
+ self.verbose = (
+ False if node_config is None else node_config.get("verbose", False)
+ )
def execute(self, state: dict) -> dict:
"""
- Executes the node's logic to merge the answers from multiple graph instances into a
+ Executes the node's logic to merge the answers from multiple graph instances into a
single answer.
Args:
@@ -53,8 +62,7 @@ def execute(self, state: dict) -> dict:
that the necessary information for generating an answer is missing.
"""
- if self.verbose:
- print(f"--- Executing {self.node_name} Node ---")
+ self.logger.info(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
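
Judging from its imports, MergeAnswersNode builds a prompt → model → JSON-parser chain. A self-contained sketch of that shape, with a hypothetical template and a `RunnableLambda` standing in for `node_config["llm_model"]` so it runs offline:

```python
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableLambda

parser = JsonOutputParser()
prompt = PromptTemplate(
    template="Merge these partial answers into one JSON object:\n{answers}\n{format_instructions}",
    input_variables=["answers"],
    partial_variables={"format_instructions": parser.get_format_instructions()},
)
# Fake model for the sketch; it always replies with a fixed JSON string.
fake_llm = RunnableLambda(lambda _: '{"title": "merged"}')

merge_chain = prompt | fake_llm | parser
print(merge_chain.invoke({"answers": '{"title": "a"} {"title": "b"}'}))  # {'title': 'merged'}
```
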
diff --git a/scrapegraphai/nodes/parse_node.py b/scrapegraphai/nodes/parse_node.py
index 39e40a23..77074d65 100644
--- a/scrapegraphai/nodes/parse_node.py
+++ b/scrapegraphai/nodes/parse_node.py
@@ -3,17 +3,20 @@
"""
from typing import List, Optional
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_transformers import Html2TextTransformer
+
+from ..utils.logging import get_logger
from .base_node import BaseNode
class ParseNode(BaseNode):
"""
- A node responsible for parsing HTML content from a document.
+ A node responsible for parsing HTML content from a document.
The parsed content is split into chunks for further processing.
- This node enhances the scraping workflow by allowing for targeted extraction of
+ This node enhances the scraping workflow by allowing for targeted extraction of
content, thereby optimizing the processing of large HTML documents.
Attributes:
@@ -26,13 +29,23 @@ class ParseNode(BaseNode):
node_name (str): The unique identifier name for the node, defaulting to "Parse".
"""
- def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None, node_name: str = "Parse"):
+ def __init__(
+ self,
+ input: str,
+ output: List[str],
+ node_config: Optional[dict] = None,
+ node_name: str = "Parse",
+ ):
super().__init__(node_name, "node", input, output, 1, node_config)
- self.verbose = False if node_config is None else node_config.get("verbose", False)
- self.parse_html = True if node_config is None else node_config.get("parse_html", True)
+ self.verbose = (
+ False if node_config is None else node_config.get("verbose", False)
+ )
+ self.parse_html = (
+ True if node_config is None else node_config.get("parse_html", True)
+ )
- def execute(self, state: dict) -> dict:
+ def execute(self, state: dict) -> dict:
"""
Executes the node's logic to parse the HTML document content and split it into chunks.
@@ -48,8 +61,7 @@ def execute(self, state: dict) -> dict:
necessary information for parsing the content is missing.
"""
- if self.verbose:
- print(f"--- Executing {self.node_name} Node ---")
+ self.logger.info(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
@@ -65,12 +77,11 @@ def execute(self, state: dict) -> dict:
# Parse the document
docs_transformed = input_data[0]
if self.parse_html:
- docs_transformed = Html2TextTransformer(
- ).transform_documents(input_data[0])
+ docs_transformed = Html2TextTransformer().transform_documents(input_data[0])
docs_transformed = docs_transformed[0]
chunks = text_splitter.split_text(docs_transformed.page_content)
-
+
state.update({self.output[0]: chunks})
return state
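
For reference, the transform-then-split step this node performs can be reproduced in isolation; the chunk size below is an assumed value, as the real node takes its splitter settings from `node_config`:

```python
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_transformers import Html2TextTransformer

docs = [Document(page_content="<html><body><h1>Title</h1><p>Some text to split.</p></body></html>")]

# Strip the markup to plain text, then chunk it for downstream nodes.
docs_transformed = Html2TextTransformer().transform_documents(docs)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=128, chunk_overlap=0)
chunks = text_splitter.split_text(docs_transformed[0].page_content)
print(chunks)
```
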
diff --git a/scrapegraphai/nodes/rag_node.py b/scrapegraphai/nodes/rag_node.py
index 27d97b6e..6d26bd1c 100644
--- a/scrapegraphai/nodes/rag_node.py
+++ b/scrapegraphai/nodes/rag_node.py
@@ -3,12 +3,17 @@
"""
from typing import List, Optional
+
from langchain.docstore.document import Document
from langchain.retrievers import ContextualCompressionRetriever
-from langchain.retrievers.document_compressors import EmbeddingsFilter, DocumentCompressorPipeline
+from langchain.retrievers.document_compressors import (
+ DocumentCompressorPipeline,
+ EmbeddingsFilter,
+)
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain_community.vectorstores import FAISS
+from ..utils.logging import get_logger
from .base_node import BaseNode
@@ -31,13 +36,20 @@ class RAGNode(BaseNode):
node_name (str): The unique identifier name for the node, defaulting to "Parse".
"""
- def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None, node_name: str = "RAG"):
+ def __init__(
+ self,
+ input: str,
+ output: List[str],
+ node_config: Optional[dict] = None,
+ node_name: str = "RAG",
+ ):
super().__init__(node_name, "node", input, output, 2, node_config)
self.llm_model = node_config["llm_model"]
self.embedder_model = node_config.get("embedder_model", None)
- self.verbose = False if node_config is None else node_config.get(
- "verbose", False)
+ self.verbose = (
+ False if node_config is None else node_config.get("verbose", False)
+ )
def execute(self, state: dict) -> dict:
"""
@@ -56,8 +68,7 @@ def execute(self, state: dict) -> dict:
necessary information for compressing the content is missing.
"""
- if self.verbose:
- print(f"--- Executing {self.node_name} Node ---")
+ self.logger.info(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
@@ -79,15 +90,15 @@ def execute(self, state: dict) -> dict:
)
chunked_docs.append(doc)
- if self.verbose:
- print("--- (updated chunks metadata) ---")
+ self.logger.info("--- (updated chunks metadata) ---")
# check if embedder_model is provided, if not use llm_model
- self.embedder_model = self.embedder_model if self.embedder_model else self.llm_model
+ self.embedder_model = (
+ self.embedder_model if self.embedder_model else self.llm_model
+ )
embeddings = self.embedder_model
- retriever = FAISS.from_documents(
- chunked_docs, embeddings).as_retriever()
+ retriever = FAISS.from_documents(chunked_docs, embeddings).as_retriever()
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
# similarity_threshold could be set, now k=20
@@ -107,9 +118,7 @@ def execute(self, state: dict) -> dict:
compressed_docs = compression_retriever.invoke(user_prompt)
- if self.verbose:
- print("--- (tokens compressed and vector stored) ---")
+ self.logger.info("--- (tokens compressed and vector stored) ---")
state.update({self.output[0]: compressed_docs})
return state
-
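
The retriever pipeline RAGNode assembles can be exercised end to end without an API key by swapping in `FakeEmbeddings`; everything else mirrors the calls visible in the hunks above (`k=20` per the node's comment):

```python
from langchain.docstore.document import Document
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import (
    DocumentCompressorPipeline,
    EmbeddingsFilter,
)
from langchain_community.document_transformers import EmbeddingsRedundantFilter
from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores import FAISS

embeddings = FakeEmbeddings(size=64)  # random vectors, demo only
chunked_docs = [Document(page_content=f"chunk {i}") for i in range(3)]

retriever = FAISS.from_documents(chunked_docs, embeddings).as_retriever()
redundant_filter = EmbeddingsRedundantFilter(embeddings=embeddings)
relevant_filter = EmbeddingsFilter(embeddings=embeddings, k=20)
pipeline_compressor = DocumentCompressorPipeline(
    transformers=[redundant_filter, relevant_filter]
)
compression_retriever = ContextualCompressionRetriever(
    base_compressor=pipeline_compressor, base_retriever=retriever
)

compressed_docs = compression_retriever.invoke("what is in chunk 1?")
print(len(compressed_docs))
```
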
diff --git a/scrapegraphai/nodes/robots_node.py b/scrapegraphai/nodes/robots_node.py
index 62d24d96..e5240d42 100644
--- a/scrapegraphai/nodes/robots_node.py
+++ b/scrapegraphai/nodes/robots_node.py
@@ -4,11 +4,14 @@
from typing import List, Optional
from urllib.parse import urlparse
-from langchain_community.document_loaders import AsyncChromiumLoader
-from langchain.prompts import PromptTemplate
+
from langchain.output_parsers import CommaSeparatedListOutputParser
-from .base_node import BaseNode
+from langchain.prompts import PromptTemplate
+from langchain_community.document_loaders import AsyncChromiumLoader
+
from ..helpers import robots_dictionary
+from ..utils.logging import get_logger
+from .base_node import BaseNode
class RobotsNode(BaseNode):
@@ -34,16 +37,22 @@ class RobotsNode(BaseNode):
node_name (str): The unique identifier name for the node, defaulting to "Robots".
"""
- def __init__(self, input: str, output: List[str], node_config: Optional[dict]=None,
-
- node_name: str = "Robots"):
+ def __init__(
+ self,
+ input: str,
+ output: List[str],
+ node_config: Optional[dict] = None,
+ force_scraping: bool = False,
+ node_name: str = "Robots",
+ ):
super().__init__(node_name, "node", input, output, 1)
self.llm_model = node_config["llm_model"]
self.force_scraping = force_scraping
- self.verbose = True if node_config is None else node_config.get(
- "verbose", False)
+ self.verbose = (
+ True if node_config is None else node_config.get("verbose", False)
+ )
def execute(self, state: dict) -> dict:
"""
@@ -65,8 +73,7 @@ def execute(self, state: dict) -> dict:
scraping is not enforced.
"""
- if self.verbose:
- print(f"--- Executing {self.node_name} Node ---")
+ self.logger.info(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
@@ -91,21 +98,21 @@ def execute(self, state: dict) -> dict:
"""
if not source.startswith("http"):
- raise ValueError(
- "Operation not allowed")
+ raise ValueError("Operation not allowed")
else:
parsed_url = urlparse(source)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
loader = AsyncChromiumLoader(f"{base_url}/robots.txt")
document = loader.load()
- if "ollama" in self.llm_model.model_name:
- self.llm_model.model_name = self.llm_model.model_name.split(
- "/")[-1]
- model = self.llm_model.model_name.split("/")[-1]
+ if "ollama" in self.llm_model["model_name"]:
+ self.llm_model["model_name"] = self.llm_model["model_name"].split("/")[
+ -1
+ ]
+ model = self.llm_model["model_name"].split("/")[-1]
else:
- model = self.llm_model.model_name
+ model = self.llm_model["model_name"]
try:
agent = robots_dictionary[model]
@@ -115,27 +122,25 @@ def execute(self, state: dict) -> dict:
prompt = PromptTemplate(
template=template,
input_variables=["path"],
- partial_variables={"context": document,
- "agent": agent
- },
+ partial_variables={"context": document, "agent": agent},
)
chain = prompt | self.llm_model | output_parser
is_scrapable = chain.invoke({"path": source})[0]
if "no" in is_scrapable:
- if self.verbose:
- print("\033[31m(Scraping this website is not allowed)\033[0m")
+ self.logger.warning(
+ "\033[31m(Scraping this website is not allowed)\033[0m"
+ )
if not self.force_scraping:
- raise ValueError(
- 'The website you selected is not scrapable')
+ raise ValueError("The website you selected is not scrapable")
else:
- if self.verbose:
- print("\033[33m(WARNING: Scraping this website is not allowed but you decided to force it)\033[0m")
+ self.logger.warning(
+ "\033[33m(WARNING: Scraping this website is not allowed but you decided to force it)\033[0m"
+ )
else:
- if self.verbose:
- print("\033[32m(Scraping this website is allowed)\033[0m")
+ self.logger.info("\033[32m(Scraping this website is allowed)\033[0m")
state.update({self.output[0]: is_scrapable})
return state
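
The robots.txt location is derived from the page URL with `urlparse`; isolated below, with an illustrative URL:

```python
from urllib.parse import urlparse

source = "https://example.com/products/page?id=3"  # any page on the site
parsed_url = urlparse(source)
base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
print(f"{base_url}/robots.txt")  # https://example.com/robots.txt
```
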
diff --git a/scrapegraphai/nodes/search_internet_node.py b/scrapegraphai/nodes/search_internet_node.py
index 87f8dcb2..9fa4a8f5 100644
--- a/scrapegraphai/nodes/search_internet_node.py
+++ b/scrapegraphai/nodes/search_internet_node.py
@@ -3,8 +3,11 @@
"""
from typing import List, Optional
+
from langchain.output_parsers import CommaSeparatedListOutputParser
from langchain.prompts import PromptTemplate
+
+from ..utils.logging import get_logger
from ..utils.research_web import search_on_web
from .base_node import BaseNode
@@ -27,13 +30,19 @@ class SearchInternetNode(BaseNode):
node_name (str): The unique identifier name for the node, defaulting to "SearchInternet".
"""
- def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None,
- node_name: str = "SearchInternet"):
+ def __init__(
+ self,
+ input: str,
+ output: List[str],
+ node_config: Optional[dict] = None,
+ node_name: str = "SearchInternet",
+ ):
super().__init__(node_name, "node", input, output, 1, node_config)
self.llm_model = node_config["llm_model"]
- self.verbose = False if node_config is None else node_config.get(
- "verbose", False)
+ self.verbose = (
+ False if node_config is None else node_config.get("verbose", False)
+ )
self.max_results = node_config.get("max_results", 3)
def execute(self, state: dict) -> dict:
@@ -55,8 +64,7 @@ def execute(self, state: dict) -> dict:
necessary information for generating the answer is missing.
"""
- if self.verbose:
- print(f"--- Executing {self.node_name} Node ---")
+ self.logger.info(f"--- Executing {self.node_name} Node ---")
input_keys = self.get_input_keys(state)
@@ -87,11 +95,9 @@ def execute(self, state: dict) -> dict:
search_answer = search_prompt | self.llm_model | output_parser
search_query = search_answer.invoke({"user_prompt": user_prompt})[0]
- if self.verbose:
- print(f"Search Query: {search_query}")
+ self.logger.info(f"Search Query: {search_query}")
- answer = search_on_web(
- query=search_query, max_results=self.max_results)
+ answer = search_on_web(query=search_query, max_results=self.max_results)
if len(answer) == 0:
# raise an exception if no answer is found
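
The node takes the first element of the parsed model reply as its search query (the `.invoke(...)[0]` above); the parser's behavior on a comma-separated reply, with a made-up reply string:

```python
from langchain.output_parsers import CommaSeparatedListOutputParser

output_parser = CommaSeparatedListOutputParser()
reply = "best pizza in naples, pizza rankings 2024"  # illustrative model output
print(output_parser.parse(reply)[0])  # 'best pizza in naples'
```
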
diff --git a/scrapegraphai/nodes/search_link_node.py b/scrapegraphai/nodes/search_link_node.py
index b15e8d26..b19095a0 100644
--- a/scrapegraphai/nodes/search_link_node.py
+++ b/scrapegraphai/nodes/search_link_node.py
@@ -4,13 +4,14 @@
# Imports from standard library
from typing import List, Optional
-from tqdm import tqdm
-
# Imports from Langchain
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableParallel
+from tqdm import tqdm
+
+from ..utils.logging import get_logger
# Imports from the library
from .base_node import BaseNode
@@ -33,13 +34,19 @@ class SearchLinkNode(BaseNode):
node_name (str): The unique identifier name for the node, defaulting to "GenerateAnswer".
"""
- def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None,
- node_name: str = "GenerateLinks"):
+ def __init__(
+ self,
+ input: str,
+ output: List[str],
+ node_config: Optional[dict] = None,
+ node_name: str = "GenerateLinks",
+ ):
super().__init__(node_name, "node", input, output, 1, node_config)
self.llm_model = node_config["llm_model"]
- self.verbose = False if node_config is None else node_config.get(
- "verbose", False)
+ self.verbose = (
+ False if node_config is None else node_config.get("verbose", False)
+ )
def execute(self, state: dict) -> dict:
"""
@@ -58,8 +65,7 @@ def execute(self, state: dict) -> dict:
necessary information for generating the answer is missing.
"""
- if self.verbose:
- print(f"--- Executing {self.node_name} Node ---")
+ self.logger.info(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
@@ -93,7 +99,13 @@ def execute(self, state: dict) -> dict:
"""
relevant_links = []
- for i, chunk in enumerate(tqdm(parsed_content_chunks, desc="Processing chunks", disable=not self.verbose)):
+ for i, chunk in enumerate(
+ tqdm(
+ parsed_content_chunks,
+ desc="Processing chunks",
+ disable=not self.verbose,
+ )
+ ):
merge_prompt = PromptTemplate(
template=prompt_relevant_links,
input_variables=["content", "user_prompt"],
@@ -101,7 +113,8 @@ def execute(self, state: dict) -> dict:
merge_chain = merge_prompt | self.llm_model | output_parser
# merge_chain = merge_prompt | self.llm_model
answer = merge_chain.invoke(
- {"content": chunk.page_content, "user_prompt": user_prompt})
+ {"content": chunk.page_content, "user_prompt": user_prompt}
+ )
relevant_links += answer
state.update({self.output[0]: relevant_links})
return state
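
The reflowed tqdm call keeps the old behavior: with `verbose` false the progress bar is suppressed but the chunks are still iterated. Standalone, with stand-in chunks:

```python
from tqdm import tqdm

parsed_content_chunks = ["chunk-a", "chunk-b", "chunk-c"]
verbose = False  # flip to True to see the progress bar

for i, chunk in enumerate(
    tqdm(parsed_content_chunks, desc="Processing chunks", disable=not verbose)
):
    pass  # each chunk would go through the merge chain here
```
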
diff --git a/scrapegraphai/nodes/search_node_with_context.py b/scrapegraphai/nodes/search_node_with_context.py
index 17437f6f..62de184a 100644
--- a/scrapegraphai/nodes/search_node_with_context.py
+++ b/scrapegraphai/nodes/search_node_with_context.py
@@ -3,9 +3,11 @@
"""
from typing import List, Optional
-from tqdm import tqdm
+
from langchain.output_parsers import CommaSeparatedListOutputParser
from langchain.prompts import PromptTemplate
+from tqdm import tqdm
+
from .base_node import BaseNode
@@ -27,12 +29,18 @@ class SearchLinksWithContext(BaseNode):
node_name (str): The unique identifier name for the node, defaulting to "GenerateAnswer".
"""
- def __init__(self, input: str, output: List[str], node_config: Optional[dict] = None,
- node_name: str = "GenerateAnswer"):
+ def __init__(
+ self,
+ input: str,
+ output: List[str],
+ node_config: Optional[dict] = None,
+ node_name: str = "GenerateAnswer",
+ ):
super().__init__(node_name, "node", input, output, 2, node_config)
self.llm_model = node_config["llm_model"]
- self.verbose = True if node_config is None else node_config.get(
- "verbose", False)
+ self.verbose = (
+ True if node_config is None else node_config.get("verbose", False)
+ )
def execute(self, state: dict) -> dict:
"""
@@ -51,8 +59,7 @@ def execute(self, state: dict) -> dict:
that the necessary information for generating an answer is missing.
"""
- if self.verbose:
- print(f"--- Executing {self.node_name} Node ---")
+ self.logger.info(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
@@ -90,25 +97,30 @@ def execute(self, state: dict) -> dict:
result = []
# Use tqdm to add progress bar
- for i, chunk in enumerate(tqdm(doc, desc="Processing chunks", disable=not self.verbose)):
+ for i, chunk in enumerate(
+ tqdm(doc, desc="Processing chunks", disable=not self.verbose)
+ ):
if len(doc) == 1:
prompt = PromptTemplate(
template=template_no_chunks,
input_variables=["question"],
- partial_variables={"context": chunk.page_content,
- "format_instructions": format_instructions},
+ partial_variables={
+ "context": chunk.page_content,
+ "format_instructions": format_instructions,
+ },
)
else:
prompt = PromptTemplate(
template=template_chunks,
input_variables=["question"],
- partial_variables={"context": chunk.page_content,
- "chunk_id": i + 1,
- "format_instructions": format_instructions},
+ partial_variables={
+ "context": chunk.page_content,
+ "chunk_id": i + 1,
+ "format_instructions": format_instructions,
+ },
)
- result.extend(
- prompt | self.llm_model | output_parser)
+ result.extend(prompt | self.llm_model | output_parser)
state["urls"] = result
return state
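
How `partial_variables` pre-binds the per-chunk values so only `question` remains free at invocation time; the template text is a stand-in for the library's real one:

```python
from langchain.prompts import PromptTemplate

prompt = PromptTemplate(
    template="Context (chunk {chunk_id}): {context}\nQuestion: {question}",
    input_variables=["question"],
    partial_variables={"context": "...chunk text...", "chunk_id": 1},
)
print(prompt.format(question="Which links matter?"))
```
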
diff --git a/scrapegraphai/nodes/text_to_speech_node.py b/scrapegraphai/nodes/text_to_speech_node.py
index d9fe7ca4..59e3fb8b 100644
--- a/scrapegraphai/nodes/text_to_speech_node.py
+++ b/scrapegraphai/nodes/text_to_speech_node.py
@@ -3,6 +3,8 @@
"""
from typing import List, Optional
+
+from ..utils.logging import get_logger
from .base_node import BaseNode
@@ -21,12 +23,19 @@ class TextToSpeechNode(BaseNode):
node_name (str): The unique identifier name for the node, defaulting to "TextToSpeech".
"""
- def __init__(self, input: str, output: List[str],
- node_config: Optional[dict]=None, node_name: str = "TextToSpeech"):
+ def __init__(
+ self,
+ input: str,
+ output: List[str],
+ node_config: Optional[dict] = None,
+ node_name: str = "TextToSpeech",
+ ):
super().__init__(node_name, "node", input, output, 1, node_config)
self.tts_model = node_config["tts_model"]
- self.verbose = False if node_config is None else node_config.get("verbose", False)
+ self.verbose = (
+ False if node_config is None else node_config.get("verbose", False)
+ )
def execute(self, state: dict) -> dict:
"""
@@ -35,7 +44,7 @@ def execute(self, state: dict) -> dict:
Args:
state (dict): The current state of the graph. The input keys will be used to fetch the
correct data types from the state.
-
+
Returns:
dict: The updated state with the output key containing the audio generated from the text.
@@ -44,8 +53,7 @@ def execute(self, state: dict) -> dict:
necessary information for generating the audio is missing.
"""
- if self.verbose:
- print(f"--- Executing {self.node_name} Node ---")
+ self.logger.info(f"--- Executing {self.node_name} Node ---")
# Interpret input keys based on the provided input expression
input_keys = self.get_input_keys(state)
diff --git a/scrapegraphai/utils/__init__.py b/scrapegraphai/utils/__init__.py
index 72a8b96c..ee647466 100644
--- a/scrapegraphai/utils/__init__.py
+++ b/scrapegraphai/utils/__init__.py
@@ -9,3 +9,4 @@
from .save_audio_from_bytes import save_audio_from_bytes
from .sys_dynamic_import import dynamic_import, srcfile_import
from .cleanup_html import cleanup_html
+from .logging import *
\ No newline at end of file
diff --git a/scrapegraphai/utils/logging.py b/scrapegraphai/utils/logging.py
new file mode 100644
index 00000000..b4a677dd
--- /dev/null
+++ b/scrapegraphai/utils/logging.py
@@ -0,0 +1,139 @@
+"""A centralized logging system for any library
+
+source code inspired by https://gist.github.com/DiTo97/9a0377f24236b66134eb96da1ec1693f
+"""
+
+import logging
+import os
+import sys
+import threading
+from functools import lru_cache
+
+
+_library_name = __name__.split(".", maxsplit=1)[0]
+
+_default_handler = None
+_default_logging_level = logging.WARNING
+
+_semaphore = threading.Lock()
+
+
+def _get_library_root_logger() -> logging.Logger:
+ return logging.getLogger(_library_name)
+
+
+def _set_library_root_logger() -> None:
+ global _default_handler
+
+ with _semaphore:
+ if _default_handler:
+ return
+
+ _default_handler = logging.StreamHandler() # sys.stderr as stream
+
+ # https://github.com/pyinstaller/pyinstaller/issues/7334#issuecomment-1357447176
+ if sys.stderr is None:
+ sys.stderr = open(os.devnull, "w")
+
+ _default_handler.flush = sys.stderr.flush
+
+ library_root_logger = _get_library_root_logger()
+ library_root_logger.addHandler(_default_handler)
+ library_root_logger.setLevel(_default_logging_level)
+ library_root_logger.propagate = False
+
+
+def get_logger(name: str | None = None) -> logging.Logger:
+ _set_library_root_logger()
+ return logging.getLogger(name or _library_name)
+
+
+def get_verbosity() -> int:
+ _set_library_root_logger()
+ return _get_library_root_logger().getEffectiveLevel()
+
+
+def set_verbosity(verbosity: int) -> None:
+ _set_library_root_logger()
+ _get_library_root_logger().setLevel(verbosity)
+
+
+def set_verbosity_debug() -> None:
+ set_verbosity(logging.DEBUG)
+
+
+def set_verbosity_info() -> None:
+ set_verbosity(logging.INFO)
+
+
+def set_verbosity_warning() -> None:
+ set_verbosity(logging.WARNING)
+
+
+def set_verbosity_error() -> None:
+ set_verbosity(logging.ERROR)
+
+
+def set_verbosity_fatal() -> None:
+ set_verbosity(logging.FATAL)
+
+
+def set_handler(handler: logging.Handler) -> None:
+ _set_library_root_logger()
+
+ assert handler is not None
+
+ _get_library_root_logger().addHandler(handler)
+
+
+def set_default_handler() -> None:
+ set_handler(_default_handler)
+
+
+def unset_handler(handler: logging.Handler) -> None:
+ _set_library_root_logger()
+
+ assert handler is not None
+
+ _get_library_root_logger().removeHandler(handler)
+
+
+def unset_default_handler() -> None:
+ unset_handler(_default_handler)
+
+
+def set_propagation() -> None:
+ _get_library_root_logger().propagate = True
+
+
+def unset_propagation() -> None:
+ _get_library_root_logger().propagate = False
+
+
+def set_formatting() -> None:
+ """sets formatting for all handlers bound to the root logger
+
+ ```
+ [levelname|filename|line number] time >> message
+ ```
+ """
+ formatter = logging.Formatter(
+ "[%(levelname)s|%(filename)s:%(lineno)s] %(asctime)s >> %(message)s"
+ )
+
+ for handler in _get_library_root_logger().handlers:
+ handler.setFormatter(formatter)
+
+
+def unset_formatting() -> None:
+ for handler in _get_library_root_logger().handlers:
+ handler.setFormatter(None)
+
+
+@lru_cache(None)
+def warning_once(self, *args, **kwargs):
+ """emits warning logs with the same message only once"""
+ self.warning(*args, **kwargs)
+
+
+logging.Logger.warning_once = warning_once
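
Typical use of the new module from application code: raise the library logger above its WARNING default, opt into the shared format, and use the deduplicated warning helper:

```python
from scrapegraphai.utils.logging import (
    get_logger,
    set_formatting,
    set_verbosity_info,
)

set_verbosity_info()
set_formatting()

logger = get_logger(__name__)
logger.info("--- Executing SomeNode Node ---")
logger.warning_once("emitted a single time")
logger.warning_once("emitted a single time")  # cached by lru_cache, not re-logged
```
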
diff --git a/tests/nodes/fetch_node_test.py b/tests/nodes/fetch_node_test.py
index a67f3dbb..47b8b7ee 100644
--- a/tests/nodes/fetch_node_test.py
+++ b/tests/nodes/fetch_node_test.py
@@ -1,19 +1,11 @@
-"""
-Module for testinh fetch_node
-"""
+import os
import pytest
from scrapegraphai.nodes import FetchNode
-
-@pytest.fixture
-def setup():
+def test_fetch_node_html():
"""
- setup
+ Test fetching HTML content from a live URL.
"""
- # ************************************************
- # Define the node
- # ************************************************
-
fetch_node = FetchNode(
input="url | local_dir",
output=["doc"],
@@ -22,21 +14,94 @@ def setup():
}
)
- return fetch_node
+ state = {
+ "url": "https://twitter.com/home"
+ }
-# ************************************************
-# Test the node
-# ************************************************
+ result = fetch_node.execute(state)
+ assert result is not None
-def test_fetch_node(setup):
+def test_fetch_node_json():
"""
Run the tests
"""
- state = {
- "url": "https://twitter.com/home"
+ FILE_NAME_JSON = "inputs/example.json"
+ curr_dir = os.path.dirname(os.path.realpath(__file__))
+ file_path_json = os.path.join(curr_dir, FILE_NAME_JSON)
+
+ state_json = {
+ "json": file_path_json
+ }
+
+ fetch_node_json = FetchNode(
+ input="json",
+ output=["doc"],
+ )
+
+ result_json = fetch_node_json.execute(state_json)
+
+ assert result_json is not None
+
+def test_fetch_node_xml():
+ """
+ Test fetching content from a local XML file.
+ """
+ FILE_NAME_XML = "inputs/books.xml"
+ curr_dir = os.path.dirname(os.path.realpath(__file__))
+ file_path_xml = os.path.join(curr_dir, FILE_NAME_XML)
+
+ state_xml = {
+ "xml": file_path_xml
}
- result = setup.execute(state)
+ fetch_node_xml = FetchNode(
+ input="xml",
+ output=["doc"],
+ )
- assert result is not None
+ result_xml = fetch_node_xml.execute(state_xml)
+
+ assert result_xml is not None
+
+def test_fetch_node_csv():
+ """
+ Test fetching content from a local CSV file.
+ """
+ FILE_NAME_CSV = "inputs/username.csv"
+ curr_dir = os.path.dirname(os.path.realpath(__file__))
+ file_path_csv = os.path.join(curr_dir, FILE_NAME_CSV)
+
+ state_csv = {
+ "csv": file_path_csv # Definire un dizionario con la chiave "csv" e il valore come percorso del file CSV
+ }
+
+ fetch_node_csv = FetchNode(
+ input="csv",
+ output=["doc"],
+ )
+
+ result_csv = fetch_node_csv.execute(state_csv)
+
+ assert result_csv is not None
+
+def test_fetch_node_txt():
+ """
+ Test fetching content from a local plain-text HTML file.
+ """
+ FILE_NAME_TXT = "inputs/plain_html_example.txt"
+ curr_dir = os.path.dirname(os.path.realpath(__file__))
+ file_path_txt = os.path.join(curr_dir, FILE_NAME_TXT)
+
+ state_txt = {
+ "txt": file_path_txt # Definire un dizionario con la chiave "txt" e il valore come percorso del file TXT
+ }
+
+ fetch_node_txt = FetchNode(
+ input="txt",
+ output=["doc"],
+ )
+
+ result_txt = fetch_node_txt.execute(state_txt)
+
+ assert result_txt is not None
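
The four file-based tests share one shape; if further inputs are added, a parametrized variant (sketched here as an alternative, not part of this patch) would keep the same coverage without the repetition:

```python
import os

import pytest

from scrapegraphai.nodes import FetchNode


@pytest.mark.parametrize(
    "input_key, file_name",
    [
        ("json", "inputs/example.json"),
        ("xml", "inputs/books.xml"),
        ("csv", "inputs/username.csv"),
        ("txt", "inputs/plain_html_example.txt"),
    ],
)
def test_fetch_node_local_files(input_key, file_name):
    curr_dir = os.path.dirname(os.path.realpath(__file__))
    state = {input_key: os.path.join(curr_dir, file_name)}
    result = FetchNode(input=input_key, output=["doc"]).execute(state)
    assert result is not None
```
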
diff --git a/tests/nodes/inputs/books.xml b/tests/nodes/inputs/books.xml
new file mode 100644
index 00000000..e3d1fe87
--- /dev/null
+++ b/tests/nodes/inputs/books.xml
@@ -0,0 +1,120 @@
+
+