diff --git a/scrapegraphai/utils/tokenizer.py b/scrapegraphai/utils/tokenizer.py
index f6650672..8d5577fd 100644
--- a/scrapegraphai/utils/tokenizer.py
+++ b/scrapegraphai/utils/tokenizer.py
@@ -6,6 +6,7 @@
 from langchain_ollama import ChatOllama
 from langchain_mistralai import ChatMistralAI
 from langchain_core.language_models.chat_models import BaseChatModel
+from transformers import GPT2TokenizerFast
 
 def num_tokens_calculus(string: str, llm_model: BaseChatModel) -> int:
     """
@@ -23,6 +24,13 @@ def num_tokens_calculus(string: str, llm_model: BaseChatModel) -> int:
         from .tokenizers.tokenizer_ollama import num_tokens_ollama
         num_tokens_fn = num_tokens_ollama
 
+    elif isinstance(llm_model, GPT2TokenizerFast):
+        def num_tokens_gpt2(text: str, model: BaseChatModel) -> int:
+            tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+            tokens = tokenizer.encode(text)
+            return len(tokens)
+        num_tokens_fn = num_tokens_gpt2
+
     else:
         from .tokenizers.tokenizer_openai import num_tokens_openai
         num_tokens_fn = num_tokens_openai
diff --git a/scrapegraphai/utils/tokenizers/tokenizer_ollama.py b/scrapegraphai/utils/tokenizers/tokenizer_ollama.py
index a981e25c..feb59e6b 100644
--- a/scrapegraphai/utils/tokenizers/tokenizer_ollama.py
+++ b/scrapegraphai/utils/tokenizers/tokenizer_ollama.py
@@ -3,6 +3,7 @@
 """
 from langchain_core.language_models.chat_models import BaseChatModel
 from ..logging import get_logger
+from transformers import GPT2TokenizerFast
 
 def num_tokens_ollama(text: str, llm_model:BaseChatModel) -> int:
     """
@@ -21,8 +22,12 @@ def num_tokens_ollama(text: str, llm_model:BaseChatModel) -> int:
 
     logger.debug(f"Counting tokens for text of {len(text)} characters")
 
+    if isinstance(llm_model, GPT2TokenizerFast):
+        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+        tokens = tokenizer.encode(text)
+        return len(tokens)
+
     # Use langchain token count implementation
     # NB: https://github.com/ollama/ollama/issues/1716#issuecomment-2074265507
     tokens = llm_model.get_num_tokens(text)
     return tokens
-
diff --git a/tests/graphs/smart_scraper_ollama_test.py b/tests/graphs/smart_scraper_ollama_test.py
index a358feb6..179acc36 100644
--- a/tests/graphs/smart_scraper_ollama_test.py
+++ b/tests/graphs/smart_scraper_ollama_test.py
@@ -3,6 +3,7 @@
 """
 import pytest
 from scrapegraphai.graphs import SmartScraperGraph
+from transformers import GPT2TokenizerFast
 
 
 @pytest.fixture
@@ -50,3 +51,11 @@ def test_get_execution_info(graph_config: dict):
     graph_exec_info = smart_scraper_graph.get_execution_info()
 
     assert graph_exec_info is not None
+
+
+def test_gpt2_tokenizer_loading():
+    """
+    Test loading of GPT2TokenizerFast
+    """
+    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+    assert tokenizer is not None
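
A minimal usage sketch of the new branch, not part of the diff: it assumes the import path scrapegraphai.utils.tokenizer matches the file above, that the transformers package is installed, and that the "gpt2" weights can be downloaded. Passing the tokenizer object itself as llm_model is what routes num_tokens_calculus into the added isinstance(llm_model, GPT2TokenizerFast) branch.

# Hypothetical example; assumes transformers is installed and the
# scrapegraphai.utils.tokenizer import path matches the patched file.
from transformers import GPT2TokenizerFast
from scrapegraphai.utils.tokenizer import num_tokens_calculus

# Load the GPT-2 fast tokenizer (downloads vocab files on first use).
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

# The tokenizer instance stands in for llm_model, so the new
# GPT2TokenizerFast branch handles the token counting.
count = num_tokens_calculus("Counting tokens with the GPT-2 tokenizer.", tokenizer)
print(count)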