Skip to content

Commit f55e01b

Browse files
author
bram
committed
Added model from api
1 parent 6a210fe commit f55e01b

File tree

1 file changed

+27
-11
lines changed

1 file changed

+27
-11
lines changed

python_gpt_po/po_translator.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,17 @@ def get_available_models(provider_clients: ProviderClients, provider: ModelProvi
238238
logging.error("Error fetching models from %s: %s", provider.value, str(e))
239239
return []
240240

241+
@staticmethod
242+
def validate_model(provider_clients: ProviderClients, provider: ModelProvider, model: str) -> bool:
243+
"""
244+
Validates whether the specified model is available for the given provider.
245+
Uses prefix matching so that a shorthand (e.g. "claude") will match a full model name.
246+
"""
247+
available_models = ModelManager.get_available_models(provider_clients, provider)
248+
if not available_models:
249+
return False
250+
return any(avail.lower().startswith(model.lower()) for avail in available_models)
251+
241252

242253
class TranslationService:
243254
"""Class to encapsulate translation functionalities."""
@@ -871,17 +882,22 @@ def main():
871882
ModelProvider.DEEPSEEK: "deepseek-chat"
872883
}
873884

874-
# Use specified model or default for the provider
875-
model = args.model or default_models.get(provider)
876-
877-
# Validate the selected model is available
878-
if not model_manager.validate_model(provider_clients, provider, model):
879-
logging.warning(
880-
"Model '%s' not found for provider %s. "
881-
"Using default model %s.",
882-
model, provider.value, default_models.get(provider)
883-
)
884-
model = default_models.get(provider)
885+
if args.model:
886+
model = args.model
887+
if not model_manager.validate_model(provider_clients, provider, model):
888+
logging.warning(
889+
"Model '%s' not found for provider %s. Using default model %s.",
890+
model, provider.value, default_models.get(provider)
891+
)
892+
model = default_models.get(provider)
893+
else:
894+
available_models = model_manager.get_available_models(provider_clients, provider)
895+
if available_models:
896+
model = available_models[0]
897+
logging.info("No model specified; using available model: %s", model)
898+
else:
899+
model = default_models.get(provider)
900+
logging.warning("No available models found from API; defaulting to %s", model)
885901

886902
# Parse language codes and detailed language names
887903
lang_codes = [lang.strip() for lang in args.lang.split(',')]

0 commit comments

Comments
 (0)