From ade28fca2c3fdf40f28a80854e3b8435a52a6930 Mon Sep 17 00:00:00 2001 From: Federico Aguzzi <62149513+f-aguzzi@users.noreply.github.com> Date: Fri, 2 Aug 2024 12:15:43 +0200 Subject: [PATCH] fix(AbstractGraph): instantiation of Azure GPT models Closes #498 --- requirements-dev.lock | 1 + requirements.lock | 1 + requirements.txt | 1 + scrapegraphai/graphs/abstract_graph.py | 8 ++++---- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/requirements-dev.lock b/requirements-dev.lock index 24b7156d..d14f9d42 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -185,6 +185,7 @@ graphviz==0.20.3 # via scrapegraphai greenlet==3.0.3 # via playwright + # via sqlalchemy groq==0.9.0 # via langchain-groq grpc-google-iam-v1==0.13.1 diff --git a/requirements.lock b/requirements.lock index 0e8bb930..7dbac1f3 100644 --- a/requirements.lock +++ b/requirements.lock @@ -133,6 +133,7 @@ graphviz==0.20.3 # via scrapegraphai greenlet==3.0.3 # via playwright + # via sqlalchemy groq==0.9.0 # via langchain-groq grpc-google-iam-v1==0.13.1 diff --git a/requirements.txt b/requirements.txt index 8f3f5da5..9c11363c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,3 +23,4 @@ semchunk>=1.0.1 langchain-fireworks>=0.1.3 langchain-community>=0.2.9 langchain-huggingface>=0.0.3 +browserbase==0.3.0 diff --git a/scrapegraphai/graphs/abstract_graph.py b/scrapegraphai/graphs/abstract_graph.py index a7493351..f07bcb10 100644 --- a/scrapegraphai/graphs/abstract_graph.py +++ b/scrapegraphai/graphs/abstract_graph.py @@ -146,6 +146,10 @@ def handle_model(model_name, provider, token_key, default_token=8192): llm_params["model"] = model_name return init_chat_model(**llm_params) + if "azure" in llm_params["model"]: + model_name = llm_params["model"].split("/")[-1] + return handle_model(model_name, "azure_openai", model_name) + if "gpt-" in llm_params["model"]: return handle_model(llm_params["model"], "openai", llm_params["model"]) @@ -154,10 +158,6 @@ def handle_model(model_name, provider, token_key, default_token=8192): token_key = llm_params["model"].split("/")[-1] return handle_model(model_name, "fireworks", token_key) - if "azure" in llm_params["model"]: - model_name = llm_params["model"].split("/")[-1] - return handle_model(model_name, "azure_openai", model_name) - if "gemini" in llm_params["model"]: model_name = llm_params["model"].split("/")[-1] return handle_model(model_name, "google_genai", model_name)