From 4e16c9a81d1385045a03288f1a02ff4233b14ff9 Mon Sep 17 00:00:00 2001 From: duke147 <85614628@qq.com> Date: Wed, 5 Jun 2024 17:37:35 +0800 Subject: [PATCH] support ernie --- scrapegraphai/builders/graph_builder.py | 3 ++ scrapegraphai/graphs/abstract_graph.py | 8 ++++ scrapegraphai/models/ernie.py | 17 +++++++ tests/graphs/smart_scraper_ernie_test.py | 57 ++++++++++++++++++++++++ 4 files changed, 85 insertions(+) create mode 100644 scrapegraphai/models/ernie.py create mode 100644 tests/graphs/smart_scraper_ernie_test.py diff --git a/scrapegraphai/builders/graph_builder.py b/scrapegraphai/builders/graph_builder.py index 7280c50b..ab19a251 100644 --- a/scrapegraphai/builders/graph_builder.py +++ b/scrapegraphai/builders/graph_builder.py @@ -6,6 +6,7 @@ from langchain.chains import create_extraction_chain from ..models import OpenAI, Gemini from ..helpers import nodes_metadata, graph_schema +from ..models.ernie import Ernie class GraphBuilder: @@ -73,6 +74,8 @@ def _create_llm(self, llm_config: dict): return OpenAI(llm_params) elif "gemini" in llm_params["model"]: return Gemini(llm_params) + elif "ernie" in llm_params["model"]: + return Ernie(llm_params) raise ValueError("Model not supported") def _generate_nodes_description(self): diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index 7814efa8..b5f3a681 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -24,6 +24,7 @@ OpenAI, OneApi ) +from ..models.ernie import Ernie from ..utils.logging import set_verbosity_debug, set_verbosity_warning from ..helpers import models_tokens @@ -272,6 +273,13 @@ def _create_llm(self, llm_config: dict, chat=False) -> object: print("model not found, using default token size (8192)") self.model_token = 8192 return DeepSeek(llm_params) + elif "ernie" in llm_params["model"]: + try: + self.model_token = models_tokens["ernie"][llm_params["model"]] + except KeyError: + print("model not found, using 
default token size (8192)") + self.model_token = 8192 + return Ernie(llm_params) else: raise ValueError("Model provided by the configuration not supported") diff --git a/scrapegraphai/models/ernie.py b/scrapegraphai/models/ernie.py new file mode 100644 index 00000000..0b4701e1 --- /dev/null +++ b/scrapegraphai/models/ernie.py @@ -0,0 +1,17 @@ +""" +Ernie Module +""" +from langchain_community.chat_models import ErnieBotChat + + +class Ernie(ErnieBotChat): + """ + A wrapper for the ErnieBotChat class that provides default configuration + and could be extended with additional methods if needed. + + Args: + llm_config (dict): Configuration parameters for the language model. + """ + + def __init__(self, llm_config: dict): + super().__init__(**llm_config) diff --git a/tests/graphs/smart_scraper_ernie_test.py b/tests/graphs/smart_scraper_ernie_test.py new file mode 100644 index 00000000..5efd8d0b --- /dev/null +++ b/tests/graphs/smart_scraper_ernie_test.py @@ -0,0 +1,57 @@ +""" +Module for testing the smart scraper class +""" +import pytest +from scrapegraphai.graphs import SmartScraperGraph + + +@pytest.fixture +def graph_config(): + """ + Configuration of the graph + """ + return { + "llm": { + "model": "ernie-bot-turbo", + "ernie_client_id": "", + "ernie_client_secret": "", + "temperature": 0.1 + }, + "embeddings": { + "model": "ollama/nomic-embed-text", + "temperature": 0, + "base_url": "http://localhost:11434", + } + } + + +def test_scraping_pipeline(graph_config: dict): + """ + Start of the scraping pipeline + """ + smart_scraper_graph = SmartScraperGraph( + prompt="List me all the news with their description.", + source="https://perinim.github.io/projects", + config=graph_config + ) + + result = smart_scraper_graph.run() + + assert result is not None + + +def test_get_execution_info(graph_config: dict): + """ + Get the execution info + """ + smart_scraper_graph = SmartScraperGraph( + prompt="List me all the news with their description.", + 
source="https://perinim.github.io/projects", + config=graph_config + ) + + smart_scraper_graph.run() + + graph_exec_info = smart_scraper_graph.get_execution_info() + + assert graph_exec_info is not None \ No newline at end of file