Skip to content

Commit 5aefb81

Browse files
Merge pull request #2448 from solliancenet/cj-fix-bedrock-timeout
Fix issue with Bedrock timeout
2 parents a1236bb + 31ab78b commit 5aefb81

File tree

1 file changed

+23
-18
lines changed

1 file changed

+23
-18
lines changed

src/python/PythonSDK/foundationallm/langchain/language_models/language_model_factory.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import boto3
2+
import botocore
23
import json
34
from azure.identity import DefaultAzureCredential, get_bearer_token_provider
45
from google.oauth2 import service_account
@@ -17,11 +18,11 @@
1718
from foundationallm.utils import ObjectUtils
1819

1920
class LanguageModelFactory:
20-
21+
2122
def __init__(self, objects:dict, config: Configuration):
2223
self.objects = objects
2324
self.config = config
24-
25+
2526
def get_language_model(self,
2627
ai_model_object_id:str,
2728
override_operation_type: OperationTypes = None
@@ -43,24 +44,24 @@ def get_language_model(self,
4344
if ai_model is None:
4445
raise LangChainException("AI model configuration settings are missing.", 400)
4546

46-
api_endpoint = ObjectUtils.get_object_by_id(ai_model.endpoint_object_id, self.objects, APIEndpointConfiguration)
47+
api_endpoint = ObjectUtils.get_object_by_id(ai_model.endpoint_object_id, self.objects, APIEndpointConfiguration)
4748
if api_endpoint is None:
4849
raise LangChainException("API endpoint configuration settings are missing.", 400)
49-
50+
5051
match api_endpoint.provider:
5152
case LanguageModelProvider.MICROSOFT:
5253
op_type = api_endpoint.operation_type
5354
if override_operation_type is not None:
5455
op_type = override_operation_type
55-
if api_endpoint.authentication_type == AuthenticationTypes.AZURE_IDENTITY:
56+
if api_endpoint.authentication_type == AuthenticationTypes.AZURE_IDENTITY:
5657
try:
5758
scope = api_endpoint.authentication_parameters.get('scope', 'https://cognitiveservices.azure.com/.default')
5859
# Set up a Azure AD token provider.
5960
token_provider = get_bearer_token_provider(
6061
DefaultAzureCredential(exclude_environment_credential=True),
6162
scope
6263
)
63-
64+
6465
if op_type == OperationTypes.CHAT:
6566
language_model = AzureChatOpenAI(
6667
azure_endpoint=api_endpoint.url,
@@ -69,27 +70,27 @@ def get_language_model(self,
6970
azure_ad_token_provider=token_provider,
7071
azure_deployment=ai_model.deployment_name
7172
)
72-
elif op_type == OperationTypes.ASSISTANTS_API or op_type == OperationTypes.IMAGE_SERVICES:
73+
elif op_type == OperationTypes.ASSISTANTS_API or op_type == OperationTypes.IMAGE_SERVICES:
7374
# Assistants API clients can't have deployment as that is assigned at the assistant level.
7475
language_model = async_aoi(
7576
azure_endpoint=api_endpoint.url,
76-
api_version=api_endpoint.api_version,
77+
api_version=api_endpoint.api_version,
7778
azure_ad_token_provider=token_provider
7879
)
7980
else:
8081
raise LangChainException(f"Unsupported operation type: {op_type}", 400)
8182

8283
except Exception as e:
8384
raise LangChainException(f"Failed to create Azure OpenAI API connector: {str(e)}", 500)
84-
else: # Key-based authentication
85-
try:
86-
api_key = self.config.get_value(api_endpoint.authentication_parameters.get('api_key_configuration_name'))
85+
else: # Key-based authentication
86+
try:
87+
api_key = self.config.get_value(api_endpoint.authentication_parameters.get('api_key_configuration_name'))
8788
except Exception as e:
8889
raise LangChainException(f"Failed to retrieve API key: {str(e)}", 500)
8990

9091
if api_key is None:
9192
raise LangChainException("API key is missing from the configuration settings.", 400)
92-
93+
9394
if op_type == OperationTypes.CHAT:
9495
language_model = AzureChatOpenAI(
9596
azure_endpoint=api_endpoint.url,
@@ -121,6 +122,8 @@ def get_language_model(self,
121122
else OpenAI(base_url=api_endpoint.url, api_key=api_key)
122123
)
123124
case LanguageModelProvider.BEDROCK:
125+
boto3_config = botocore.config.Config(connect_timeout=60, read_timeout=api_endpoint.timeout_seconds)
126+
124127
if api_endpoint.authentication_type == AuthenticationTypes.AZURE_IDENTITY:
125128
# Get Azure scope for federated authentication as well as the AWS role ARN (Amazon Resource Name).
126129
try:
@@ -159,7 +162,8 @@ def get_language_model(self,
159162
region_name = region,
160163
aws_access_key_id = creds["AccessKeyId"],
161164
aws_secret_access_key = creds["SecretAccessKey"],
162-
aws_session_token= creds["SessionToken"]
165+
aws_session_token= creds["SessionToken"],
166+
config=boto3_config
163167
)
164168
else: # Key-based authentication
165169
try:
@@ -184,10 +188,11 @@ def get_language_model(self,
184188
model= ai_model.deployment_name,
185189
region_name = region,
186190
aws_access_key_id = access_key,
187-
aws_secret_access_key = secret_key
191+
aws_secret_access_key = secret_key,
192+
config=boto3_config
188193
)
189194
case LanguageModelProvider.VERTEXAI:
190-
# Only supports service account authentication via JSON credentials stored in key vault.
195+
# Only supports service account authentication via JSON credentials stored in key vault.
191196
# Uses the authentication parameter: service_account_credentials to get the application configuration key for this value.
192197
try:
193198
service_account_credentials_definition = json.loads(self.config.get_value(api_endpoint.authentication_parameters.get('service_account_credentials')))
@@ -197,8 +202,8 @@ def get_language_model(self,
197202
if not service_account_credentials_definition:
198203
raise LangChainException("Service account credentials are missing from the configuration settings.", 400)
199204

200-
service_account_credentials = service_account.Credentials.from_service_account_info(service_account_credentials_definition)
201-
language_model = ChatVertexAI(
205+
service_account_credentials = service_account.Credentials.from_service_account_info(service_account_credentials_definition)
206+
language_model = ChatVertexAI(
202207
model=ai_model.deployment_name,
203208
temperature=0,
204209
max_tokens=None,
@@ -211,5 +216,5 @@ def get_language_model(self,
211216
for key, value in ai_model.model_parameters.items():
212217
if hasattr(language_model, key):
213218
setattr(language_model, key, value)
214-
219+
215220
return language_model

0 commit comments

Comments
 (0)