
Updates to fix #617 for OpenAI and Mistral models #643


Merged
1 change: 1 addition & 0 deletions requirements.txt
@@ -25,3 +25,4 @@ google>=3.0.0
undetected-playwright>=0.3.0
semchunk>=1.0.1
browserbase>=0.3.0
mistral-common==1.3.4
5 changes: 3 additions & 2 deletions scrapegraphai/graphs/abstract_graph.py
@@ -141,7 +141,7 @@ def handle_model(model_name, provider, token_key, default_token=8192):

known_models = {"chatgpt","gpt","openai", "azure_openai", "google_genai",
"ollama", "oneapi", "nvidia", "groq", "google_vertexai",
"bedrock", "mistralai", "hugging_face", "deepseek", "ernie", "fireworks"}
"bedrock", "mistral", "hugging_face", "deepseek", "ernie", "fireworks"}

if llm_params["model"].split("/")[0] not in known_models and llm_params["model"].split("-")[0] not in known_models:
raise ValueError(f"Model '{llm_params['model']}' is not supported")
@@ -164,7 +164,8 @@ def handle_model(model_name, provider, token_key, default_token=8192):
return handle_model(llm_params["model"], "google_vertexai", llm_params["model"])

elif "gpt-" in llm_params["model"]:
return handle_model(llm_params["model"], "openai", llm_params["model"])
model_name = llm_params["model"].split("/")[-1]
return handle_model(model_name, "openai", model_name)

elif "ollama" in llm_params["model"]:
model_name = llm_params["model"].split("ollama/")[-1]
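For illustration, a minimal sketch (the model strings are hypothetical) of why the extra split("/") matters: a provider-prefixed model string such as "openai/gpt-4o-mini" now resolves to the bare model name before being handed to handle_model.

# Hypothetical model strings; both now resolve to the same bare model name
for model in ("openai/gpt-4o-mini", "gpt-4o-mini"):
    model_name = model.split("/")[-1]  # strips an optional "provider/" prefix
    print(model, "->", model_name)     # both print "... -> gpt-4o-mini"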
2 changes: 1 addition & 1 deletion scrapegraphai/graphs/script_creator_graph.py
@@ -79,7 +79,7 @@ def _create_graph(self) -> BaseGraph:
}
)
generate_scraper_node = GenerateScraperNode(
input="user_prompt & (doc)",
input="user_prompt & (parsed_doc)",
output=["answer"],
node_config={
"llm_model": self.llm_model,
16 changes: 13 additions & 3 deletions scrapegraphai/nodes/generate_scraper_node.py
@@ -102,9 +102,19 @@ def execute(self, state: dict) -> dict:
TEMPLATE_NO_CHUNKS += self.additional_info

if len(doc) > 1:
raise NotImplementedError(
"Currently GenerateScraperNode cannot handle more than 1 context chunks"
)
# Short-term partial fix for issue #543 (context length exceeded).
# If ParseNode returns more than one chunk, we just use the first one, on the basis
# that the structure of the rest of the HTML page is probably very similar to the
# first chunk, so the generated script should still work.
# The better fix is to generate multiple scripts and then use the LLM to merge them.

#raise NotImplementedError(
# "Currently GenerateScraperNode cannot handle more than 1 context chunks"
#)
self.logger.warn(f"Warning: {self.node_name} Node provided with {len(doc)} chunks but can only "
"support 1, ignoring remaining chunks")
doc = [doc[0]]
template = TEMPLATE_NO_CHUNKS
else:
template = TEMPLATE_NO_CHUNKS

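As a standalone illustration of the new fallback (the doc list below is a made-up stand-in for ParseNode output): when more than one chunk arrives, a warning is emitted and only the first chunk is kept instead of raising NotImplementedError.

# Hypothetical chunks standing in for what ParseNode would put in the state
doc = ["<html chunk 1>", "<html chunk 2>", "<html chunk 3>"]

if len(doc) > 1:
    # mirrors the node's new behaviour: warn and keep only the first chunk
    print(f"Warning: GenerateScraperNode provided with {len(doc)} chunks but can only "
          "support 1, ignoring remaining chunks")
    doc = [doc[0]]

print(doc)  # ['<html chunk 1>']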
64 changes: 13 additions & 51 deletions scrapegraphai/nodes/parse_node.py
@@ -3,14 +3,15 @@
"""
from typing import List, Optional
from semchunk import chunk
from langchain_openai import ChatOpenAI
from langchain_ollama import ChatOllama
from langchain_mistralai import ChatMistralAI
#from langchain_openai import ChatOpenAI
#from langchain_ollama import ChatOllama
#from langchain_mistralai import ChatMistralAI
from langchain_community.document_transformers import Html2TextTransformer
from langchain_core.documents import Document
from .base_node import BaseNode
from ..utils.tokenizers.tokenizer_openai import num_tokens_openai
from ..utils.tokenizers.tokenizer_mistral import num_tokens_mistral
from ..utils.tokenization import chunk_text
#from ..utils.tokenizers.tokenizer_openai import num_tokens_openai
#from ..utils.tokenizers.tokenizer_mistral import num_tokens_mistral

class ParseNode(BaseNode):
"""
@@ -47,6 +48,7 @@ def __init__(
)

self.llm_model = node_config.get("llm_model")
self.chunk_size = node_config.get("chunk_size")

def execute(self, state: dict) -> dict:
"""
@@ -75,55 +77,15 @@ def execute(self, state: dict) -> dict:
docs_transformed = Html2TextTransformer().transform_documents(input_data[0])[0]
else:
docs_transformed = docs_transformed[0]

context_window = self.llm_model.name.split("/")[-1] * 0.9

if isinstance(self.llm_model, ChatOpenAI):
num_tokens = num_tokens_openai(docs_transformed.page_content)
def chunker(text):
from ..utils import chunk_text
return chunk_text(text, self.llm_model, self.chunk_size, use_semchunk=False)

chunks = []
num_chunks = num_tokens // context_window

if num_tokens % context_window != 0:
num_chunks += 1

for i in range(num_chunks):
start = i * context_window
end = (i + 1) * context_window
chunks.append(docs_transformed.page_content[start:end])

elif isinstance(self.llm_model, ChatMistralAI):
model_name = self.llm_model.name.split("/")[-1] # Extract model name
num_tokens = num_tokens_mistral(docs_transformed.page_content, model_name)

chunks = []
num_chunks = num_tokens // context_window

if num_tokens % context_window != 0:
num_chunks += 1

for i in range(num_chunks):
start = i * context_window
end = (i + 1) * context_window
chunks.append(docs_transformed.page_content[start:end])

elif isinstance(self.llm_model, ChatOllama):
# TODO: Implement ChatOllama tokenization logic
print("Ollama model processing not yet implemented.")
if isinstance(docs_transformed, Document):
chunks = chunker(docs_transformed.page_content)
else:
chunk_size = self.node_config.get("chunk_size", 4096)
chunk_size = min(chunk_size - 500, int(chunk_size * 0.9))

if isinstance(docs_transformed, Document):
chunks = chunk(text=docs_transformed.page_content,
chunk_size=chunk_size,
token_counter=lambda text: len(text.split()),
memoize=False)
else:
chunks = chunk(text=docs_transformed,
chunk_size=chunk_size,
token_counter=lambda text: len(text.split()),
memoize=False)
chunks = chunker(docs_transformed)

state.update({self.output[0]: chunks})

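A hedged sketch of how the node might now be configured (the construction below is assumed from the surrounding graph code, which is not shown in this diff): both the chat-model object and the token budget travel in node_config, and all chunking is delegated to utils.chunk_text.

from langchain_mistralai import ChatMistralAI
from scrapegraphai.nodes import ParseNode

llm = ChatMistralAI(model="open-mistral-7b", api_key="placeholder")  # placeholder key, no API call

parse_node = ParseNode(
    input="doc",
    output=["parsed_doc"],
    node_config={
        "llm_model": llm,     # used to pick the provider-specific tokenizer
        "chunk_size": 8192,   # token budget handed through to chunk_text
    },
)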
5 changes: 3 additions & 2 deletions scrapegraphai/utils/__init__.py
@@ -11,5 +11,6 @@
from .cleanup_html import cleanup_html
from .logging import *
from .convert_to_md import convert_to_md
from .tokenizers.tokenizer_openai import num_tokens_openai
from .tokenizers.tokenizer_mistral import num_tokens_mistral
#from .tokenizers.tokenizer_openai import num_tokens_openai
#from .tokenizers.tokenizer_mistral import num_tokens_mistral
from .tokenization import chunk_text
38 changes: 0 additions & 38 deletions scrapegraphai/utils/token_calculator.py

This file was deleted.

99 changes: 99 additions & 0 deletions scrapegraphai/utils/tokenization.py
@@ -0,0 +1,99 @@
"""
Module for counting tokens and splitting text into chunks
"""
from typing import List
import tiktoken
from semchunk import chunk
from langchain_openai import ChatOpenAI
from langchain_ollama import ChatOllama
from langchain_mistralai import ChatMistralAI
from ..helpers.models_tokens import models_tokens
from langchain_core.language_models.chat_models import BaseChatModel
from .logging import get_logger

def chunk_text(text: str, llm_model: BaseChatModel, chunk_size: int, use_semchunk=False) -> List[str]:
"""
Truncates text into chunks that are small enough to be processed by specified llm models.

Args:
text (str): The input text to be truncated.
llm_model (BaseChatModel): The langchain chat model object.
chunk_size (int): Number of tokens per chunk allowed.
use_semchunk: Whether to use semchunk to split the text or use a simple token count
based approach.

Returns:
List[str]: A list of text chunks, each within the token limit of the specified model.

Example:
>>> chunk_text("This is a sample text for truncation.", openai_model, chunk_size=100)
["This is a sample text", "for truncation."]

This function ensures that each chunk of text can be tokenized
by the specified model without exceeding the model's token limit.
"""



if isinstance(llm_model, ChatOpenAI):
from .tokenizers.tokenizer_openai import num_tokens_openai
num_tokens_fn = num_tokens_openai

elif isinstance(llm_model, ChatMistralAI):
from .tokenizers.tokenizer_mistral import num_tokens_mistral
num_tokens_fn = num_tokens_mistral

elif isinstance(llm_model, ChatOllama):
from .tokenizers.tokenizer_ollama import num_tokens_ollama
num_tokens_fn = num_tokens_ollama

else:
raise NotImplementedError(f"There is no tokenization implementation for model '{llm_model}'")


if use_semchunk:
def count_tokens(text):
return num_tokens_fn(text, llm_model)

chunk_size = min(chunk_size - 500, int(chunk_size * 0.9))

chunks = chunk(text=text,
chunk_size=chunk_size,
token_counter=count_tokens,
memoize=False)
return chunks

else:

num_tokens = num_tokens_fn(text, llm_model)

chunks = []
num_chunks = num_tokens // chunk_size

if num_tokens % chunk_size != 0:
num_chunks += 1

for i in range(num_chunks):
start = i * chunk_size
end = (i + 1) * chunk_size
chunks.append(text[start:end])

return chunks





#################
# previous chunking code
#################

#encoding = tiktoken.get_encoding(encoding_name)
#max_tokens = min(models_tokens[model] - 500, int(models_tokens[model] * 0.9))
#encoded_text = encoding.encode(text)

#chunks = [encoded_text[i:i + max_tokens]
# for i in range(0, len(encoded_text), max_tokens)]

#result = [encoding.decode(chunk) for chunk in chunks]

#return result
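For illustration, a hedged usage sketch of the new helper (the Mistral model name and API key are placeholders; constructing the client makes no API call): chunk_text counts tokens with the provider-specific tokenizer and returns the text split into a list of chunks sized from that count.

from langchain_mistralai import ChatMistralAI
from scrapegraphai.utils import chunk_text

llm = ChatMistralAI(model="open-mistral-7b", api_key="placeholder")

text = "Some long scraped page text. " * 500
chunks = chunk_text(text, llm, chunk_size=512)  # simple token-count based split
print(len(chunks), "chunks")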
24 changes: 20 additions & 4 deletions scrapegraphai/utils/tokenizers/tokenizer_mistral.py
@@ -1,29 +1,45 @@
"""
Tokenization utilities for Mistral models
"""
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.instruct.tool_calls import Function, Tool
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from langchain_core.language_models.chat_models import BaseChatModel
from ..logging import get_logger

def num_tokens_mistral(text: str, model_name: str) -> int:

def num_tokens_mistral(text: str, llm_model:BaseChatModel) -> int:
"""
Estimate the number of tokens in a given text using Mistral's tokenization method,
adjusted for different Mistral models.

Args:
text (str): The text to be tokenized and counted.
model_name (str): The specific Mistral model name to adjust tokenization.
llm_model (BaseChatModel): The specific Mistral model to adjust tokenization.

Returns:
int: The number of tokens in the text.
"""
tokenizer = MistralTokenizer.from_model(model_name)

logger = get_logger()

logger.debug(f"Counting tokens for text of {len(text)} characters")
try:
model = llm_model.model
except AttributeError:
raise NotImplementedError(f"The model provider you are using ('{llm_model}') "
"does not give us a model name so we cannot identify which encoding to use")

tokenizer = MistralTokenizer.from_model(model)

tokenized = tokenizer.encode_chat_completion(
ChatCompletionRequest(
tools=[],
messages=[
UserMessage(content=text),
],
model=model_name,
model=model,
)
)
tokens = tokenized.tokens
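A small hedged sketch of the revised interface (model name and key are placeholders): the counter is now handed the chat-model object itself and reads the model name from it, instead of taking a separate model_name string.

from langchain_mistralai import ChatMistralAI
from scrapegraphai.utils.tokenizers.tokenizer_mistral import num_tokens_mistral

llm = ChatMistralAI(model="open-mistral-7b", api_key="placeholder")  # no API call is made
print(num_tokens_mistral("How many tokens is this sentence?", llm))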
31 changes: 31 additions & 0 deletions scrapegraphai/utils/tokenizers/tokenizer_ollama.py
@@ -0,0 +1,31 @@
"""
Tokenization utilities for Ollama models
"""
from langchain_core.language_models.chat_models import BaseChatModel
from ..logging import get_logger


def num_tokens_ollama(text: str, llm_model:BaseChatModel) -> int:
"""
Estimate the number of tokens in a given text using Ollama's tokenization method,
adjusted for different Ollama models.

Args:
text (str): The text to be tokenized and counted.
llm_model (BaseChatModel): The specific Ollama model to adjust tokenization.

Returns:
int: The number of tokens in the text.
"""

logger = get_logger()

logger.debug(f"Counting tokens for text of {len(text)} characters")
try:
model = llm_model.model
except AttributeError:
raise NotImplementedError(f"The model provider you are using ('{llm_model}') "
"does not give us a model name so we cannot identify which encoding to use")

raise NotImplementedError(f"Ollama tokenization not implemented yet")
