From d291819be3adaf83e8d4e778820438e1516b1bf6 Mon Sep 17 00:00:00 2001
From: yusefes
Date: Thu, 17 Oct 2024 16:34:13 +0330
Subject: [PATCH] Fix tokenizer loading for GPT2

Fixes #752

Fix the issue with loading the tokenizer for 'gpt2'.

* **scrapegraphai/utils/tokenizer.py**
  - Add a check for `GPT2TokenizerFast` in the `num_tokens_calculus` function.
  - Import `GPT2TokenizerFast` from `transformers`.

* **scrapegraphai/utils/tokenizers/tokenizer_ollama.py**
  - Modify the `num_tokens_ollama` function to handle `GPT2TokenizerFast`.

* **tests/graphs/smart_scraper_ollama_test.py**
  - Add a test case to verify the tokenizer loading for `GPT2TokenizerFast`.

---
For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/ScrapeGraphAI/Scrapegraph-ai/issues/752?shareId=XXXX-XXXX-XXXX-XXXX).
---
 scrapegraphai/utils/tokenizer.py                   | 8 ++++++++
 scrapegraphai/utils/tokenizers/tokenizer_ollama.py | 7 ++++++-
 tests/graphs/smart_scraper_ollama_test.py          | 9 +++++++++
 3 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/scrapegraphai/utils/tokenizer.py b/scrapegraphai/utils/tokenizer.py
index f6650672..8d5577fd 100644
--- a/scrapegraphai/utils/tokenizer.py
+++ b/scrapegraphai/utils/tokenizer.py
@@ -6,6 +6,7 @@
 from langchain_ollama import ChatOllama
 from langchain_mistralai import ChatMistralAI
 from langchain_core.language_models.chat_models import BaseChatModel
+from transformers import GPT2TokenizerFast
 
 def num_tokens_calculus(string: str, llm_model: BaseChatModel) -> int:
     """
@@ -23,6 +24,13 @@ def num_tokens_calculus(string: str, llm_model: BaseChatModel) -> int:
         from .tokenizers.tokenizer_ollama import num_tokens_ollama
         num_tokens_fn = num_tokens_ollama
 
+    elif isinstance(llm_model, GPT2TokenizerFast):
+        def num_tokens_gpt2(text: str, model: BaseChatModel) -> int:
+            tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+            tokens = tokenizer.encode(text)
+            return len(tokens)
+        num_tokens_fn = num_tokens_gpt2
+
     else:
         from .tokenizers.tokenizer_openai import num_tokens_openai
         num_tokens_fn = num_tokens_openai
diff --git a/scrapegraphai/utils/tokenizers/tokenizer_ollama.py b/scrapegraphai/utils/tokenizers/tokenizer_ollama.py
index a981e25c..feb59e6b 100644
--- a/scrapegraphai/utils/tokenizers/tokenizer_ollama.py
+++ b/scrapegraphai/utils/tokenizers/tokenizer_ollama.py
@@ -3,6 +3,7 @@
 """
 from langchain_core.language_models.chat_models import BaseChatModel
 from ..logging import get_logger
+from transformers import GPT2TokenizerFast
 
 def num_tokens_ollama(text: str, llm_model:BaseChatModel) -> int:
     """
@@ -21,8 +22,12 @@ def num_tokens_ollama(text: str, llm_model:BaseChatModel) -> int:
 
     logger.debug(f"Counting tokens for text of {len(text)} characters")
 
+    if isinstance(llm_model, GPT2TokenizerFast):
+        tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+        tokens = tokenizer.encode(text)
+        return len(tokens)
+
     # Use langchain token count implementation
     # NB: https://github.com/ollama/ollama/issues/1716#issuecomment-2074265507
     tokens = llm_model.get_num_tokens(text)
     return tokens
-
diff --git a/tests/graphs/smart_scraper_ollama_test.py b/tests/graphs/smart_scraper_ollama_test.py
index a358feb6..179acc36 100644
--- a/tests/graphs/smart_scraper_ollama_test.py
+++ b/tests/graphs/smart_scraper_ollama_test.py
@@ -3,6 +3,7 @@
 """
 import pytest
 from scrapegraphai.graphs import SmartScraperGraph
+from transformers import GPT2TokenizerFast
 
 
 @pytest.fixture
@@ -50,3 +51,11 @@ def test_get_execution_info(graph_config: dict):
     graph_exec_info = smart_scraper_graph.get_execution_info()
 
     assert graph_exec_info is not None
+
+
+def test_gpt2_tokenizer_loading():
+    """
+    Test loading of GPT2TokenizerFast
+    """
+    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
+    assert tokenizer is not None