Skip to content

Commit 52adbfb

Browse files
authored
Merge pull request #176 from yjg30737/hotfix/all
Fix ModuleNotFoundError issues
2 parents c24f841 + 4d175b2 commit 52adbfb

File tree

1 file changed

+42
-4
lines changed

1 file changed

+42
-4
lines changed

pyqt_openai/util/script.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
import PIL.Image
2727
import numpy as np
2828
import psutil
29-
from g4f.gui.server.api import Api
29+
from g4f import ProviderType
30+
from g4f.providers.base_provider import ProviderModelMixin
3031

3132
from pyqt_openai.widgets.scrollableErrorDialog import ScrollableErrorDialog
3233

@@ -796,13 +797,32 @@ def get_g4f_image_models() -> list:
796797
models = [model["image_model"] for model in image_models]
797798
return models
798799

799-
800800
def get_g4f_image_providers(including_auto=False) -> list:
801801
"""
802802
Get all the providers that support image generation
803803
(Even though this is not a perfect way to get the providers that support image generation)
804+
(So i have to bring get_providers function directly from g4f library)
804805
"""
805-
providers = Api.get_providers()
806+
807+
def get_providers():
808+
"""
809+
The function get from g4f/gui/server/api.py
810+
"""
811+
return {
812+
provider.__name__: (provider.label
813+
if hasattr(provider, "label")
814+
else provider.__name__) +
815+
(" (WebDriver)"
816+
if "webdriver" in provider.get_parameters()
817+
else "") +
818+
(" (Auth)"
819+
if provider.needs_auth
820+
else "")
821+
for provider in __providers__
822+
if provider.working
823+
}
824+
825+
providers = get_providers()
806826
if including_auto:
807827
providers = [G4F_PROVIDER_DEFAULT] + [provider for provider in providers]
808828
return providers
@@ -812,10 +832,28 @@ def get_g4f_image_models_from_provider(provider) -> list:
812832
"""
813833
Get all the models that support image generation for a specific provider
814834
(Again, this is not a perfect way to get the models that support image generation)
835+
(So i have to bring get_provider_models function directly from g4f library)
815836
"""
816837
if provider == G4F_PROVIDER_DEFAULT:
817838
return get_g4f_image_models()
818-
return [model["model"] for model in Api.get_provider_models(provider)]
839+
840+
def get_provider_models(provider: str) -> list[dict]:
841+
"""
842+
From g4f/gui/server/api.py
843+
"""
844+
if provider in __map__:
845+
provider: ProviderType = __map__[provider]
846+
if issubclass(provider, ProviderModelMixin):
847+
return [{"model": model, "default": model == provider.default_model} for model in provider.get_models()]
848+
elif provider.supports_gpt_35_turbo or provider.supports_gpt_4:
849+
return [
850+
*([{"model": "gpt-4", "default": not provider.supports_gpt_4}] if provider.supports_gpt_4 else []),
851+
*([{"model": "gpt-3.5-turbo",
852+
"default": not provider.supports_gpt_4}] if provider.supports_gpt_35_turbo else [])
853+
]
854+
else:
855+
return []
856+
return [model["model"] for model in get_provider_models(provider)]
819857

820858

821859
def get_g4f_argument(model, messages, cur_text, stream):

0 commit comments

Comments
 (0)