Skip to content

Commit 72c1a7d

Browse files
yanbasic11zhouxuan
andauthored
feat: add OpenAI compatible API router
* feat: add OpenAI compatible API router * fix: resolve conflicts * fix: resolve conflicts * fix: trailing whitespace --------- Co-authored-by: zhouxss <34160552+11zhouxuan@users.noreply.github.com>
1 parent 257323f commit 72c1a7d

File tree

16 files changed

+1352
-153
lines changed

16 files changed

+1352
-153
lines changed

src/emd/cfn/codepipeline/template.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ Resources:
6363
- lambda:*
6464
- logs:*
6565
- elasticloadbalancing:*
66+
- application-autoscaling:*
6667
Resource:
6768
- "*"
6869
ManagedPolicyArns:
@@ -236,6 +237,7 @@ Resources:
236237
cd ..
237238
cp cfn/$ServiceType/template.yaml template.yaml
238239
cp pipeline/parameters.json parameters.json
240+
python cfn/shared/filter_parameters.py template.yaml parameters.json
239241
cat parameters.json
240242
echo post build completed on `date`
241243

src/emd/cfn/ecs/post_build.py

Lines changed: 12 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -5,138 +5,8 @@
55
import argparse
66
import sys
77
from emd.models.utils.serialize_utils import load_extra_params
8-
9-
# Post build script for ECS, it will deploy the VPC and ECS cluster.
10-
CFN_ROOT_PATH = "../cfn"
11-
WAIT_SECONDS = 10
12-
13-
14-
def wait_for_stack_completion(client, stack_name):
15-
while True:
16-
response = client.describe_stacks(StackName=stack_name)
17-
stack_status = response["Stacks"][0]["StackStatus"]
18-
while stack_status.endswith("IN_PROGRESS"):
19-
print(
20-
f"Stack {stack_name} is currently {stack_status}. Waiting for completion..."
21-
)
22-
time.sleep(WAIT_SECONDS)
23-
response = client.describe_stacks(StackName=stack_name)
24-
stack_status = response["Stacks"][0]["StackStatus"]
25-
26-
if stack_status in ["CREATE_COMPLETE", "UPDATE_COMPLETE"]:
27-
print(f"Stack {stack_name} deployment complete")
28-
break
29-
else:
30-
print(
31-
f"Post build stage failed. The stack {stack_name} is in an unexpected status: {stack_status}. Please visit the AWS CloudFormation Console to delete the stack."
32-
)
33-
sys.exit(1)
34-
35-
def get_stack_outputs(client, stack_name):
36-
response = client.describe_stacks(StackName=stack_name)
37-
return response["Stacks"][0].get("Outputs", [])
38-
39-
40-
def create_or_update_stack(client, stack_name, template_path, parameters=[]):
41-
try:
42-
wait_for_stack_completion(client, stack_name)
43-
response = client.describe_stacks(StackName=stack_name)
44-
stack_status = response["Stacks"][0]["StackStatus"]
45-
46-
if stack_status in ["CREATE_COMPLETE", "UPDATE_COMPLETE"]:
47-
print(f"Stack {stack_name} already exists. Proceeding with update.")
48-
with open(template_path, "r") as template_file:
49-
template_body = template_file.read()
50-
51-
try:
52-
response = client.update_stack(
53-
StackName=stack_name,
54-
TemplateBody=template_body,
55-
Capabilities=["CAPABILITY_NAMED_IAM"],
56-
Parameters=parameters
57-
)
58-
except Exception as e:
59-
print(f"No updates are to be performed for stack {stack_name}.")
60-
61-
print(f"Started update of stack {stack_name}")
62-
wait_for_stack_completion(client, stack_name)
63-
64-
except client.exceptions.ClientError as e:
65-
if "does not exist" in str(e):
66-
print(f"Stack {stack_name} does not exist. Proceeding with creation.")
67-
with open(template_path, "r") as template_file:
68-
template_body = template_file.read()
69-
70-
response = client.create_stack(
71-
StackName=stack_name,
72-
TemplateBody=template_body,
73-
Capabilities=["CAPABILITY_NAMED_IAM"],
74-
Parameters=parameters,
75-
EnableTerminationProtection=True,
76-
)
77-
78-
stack_id = response["StackId"]
79-
print(f"Started deployment of stack {stack_name} with ID {stack_id}")
80-
wait_for_stack_completion(client, stack_name)
81-
else:
82-
print(
83-
f"Post build stage failed. The stack {stack_name} is in an unexpected status: {stack_status}. Please visit the AWS CloudFormation Console to delete the stack."
84-
)
85-
sys.exit(1)
86-
87-
88-
def update_parameters_file(parameters_path, updates):
89-
with open(parameters_path, "r") as file:
90-
data = json.load(file)
91-
92-
data["Parameters"].update(updates)
93-
94-
with open(parameters_path, "w") as file:
95-
json.dump(data, file, indent=4)
96-
97-
98-
def deploy_vpc_template(region):
99-
client = boto3.client("cloudformation", region_name=region)
100-
stack_name = "EMD-VPC"
101-
template_path = f"{CFN_ROOT_PATH}/vpc/template.yaml"
102-
create_or_update_stack(client, stack_name, template_path)
103-
outputs = get_stack_outputs(client, stack_name)
104-
vpc_id = None
105-
subnets = None
106-
for output in outputs:
107-
if output["OutputKey"] == "VPCID" and output["OutputValue"]:
108-
vpc_id = output["OutputValue"]
109-
elif output["OutputKey"] == "Subnets" and output["OutputValue"]:
110-
subnets = output["OutputValue"]
111-
update_parameters_file("parameters.json", {"VPCID": vpc_id, "Subnets": subnets})
112-
return vpc_id, subnets
113-
114-
115-
def deploy_ecs_cluster_template(region, vpc_id, subnets):
116-
client = boto3.client("cloudformation", region_name=region)
117-
stack_name = "EMD-ECS-Cluster"
118-
template_path = f"{CFN_ROOT_PATH}/ecs/cluster.yaml"
119-
create_or_update_stack(
120-
client,
121-
stack_name,
122-
template_path,
123-
[
124-
{
125-
"ParameterKey": "VPCID",
126-
"ParameterValue": vpc_id,
127-
},
128-
{
129-
"ParameterKey": "Subnets",
130-
"ParameterValue": subnets,
131-
},
132-
],
133-
)
134-
135-
outputs = get_stack_outputs(client, stack_name)
136-
for output in outputs:
137-
update_parameters_file(
138-
"parameters.json", {output["OutputKey"]: output["OutputValue"]}
139-
)
8+
from emd.cfn.shared.ecs_cluster import deploy_ecs_cluster, remove_parameters_file
9+
# Post build script for SageMaker OpenAI Compatible Interface, it will deploy the VPC and ECS cluster with an API router Fargate ECS service.
14010

14111

14212
def post_build():
@@ -160,14 +30,19 @@ def post_build():
16030

16131
service_params = args.extra_params.get("service_params", {})
16232

163-
if "vpc_id" not in service_params:
164-
vpc_id, subnets = deploy_vpc_template(args.region)
165-
else:
33+
if "vpc_id" in service_params:
16634
vpc_id = service_params.get("vpc_id")
16735
subnets = service_params.get("subnet_ids")
168-
update_parameters_file("parameters.json", {"VPCID": vpc_id, "Subnets": subnets})
36+
else:
37+
vpc_id = None
38+
subnets = None
39+
40+
if "use_spot" in service_params and service_params.get("use_spot") == "true":
41+
use_spot = True
42+
else:
43+
use_spot = False
16944

170-
deploy_ecs_cluster_template(args.region, vpc_id, subnets)
45+
deploy_ecs_cluster(args.region, vpc_id, subnets, use_spot)
17146

17247

17348
if __name__ == "__main__":

src/emd/cfn/ecs/template.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,9 @@ Resources:
361361
TargetGroupArn: !Ref ServiceTargetGroup
362362

363363
Outputs:
364+
Model:
365+
Description: Model ID used to generate the response
366+
Value: !Join ['', [!Ref ModelId, '/', !Ref ModelTag]]
364367
PublicLoadBalancerDNSName:
365368
Description: The DNS name of the public load balancer. To use HTTPS, create an SSL certificate in AWS Certificate Manager and attach it to the load balancer.
366369
Value: !Join ['', ['http://', !Ref DNSName, '/', !Ref ModelId, '/', !Ref ModelTag]]
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import boto3
2+
import time
3+
import json
4+
import os
5+
import argparse
6+
import sys
7+
from emd.models.utils.serialize_utils import load_extra_params
8+
from emd.cfn.shared.ecs_cluster import deploy_ecs_cluster, remove_parameters_file
9+
# Post build script for SageMaker OpenAI Compatible Interface, it will deploy the VPC and ECS cluster with an API router Fargate ECS service.
10+
11+
12+
def post_build():
13+
parser = argparse.ArgumentParser()
14+
parser.add_argument("--region", type=str, required=False)
15+
parser.add_argument("--model_id", type=str, required=False)
16+
parser.add_argument("--model_tag", type=str, required=False)
17+
parser.add_argument("--framework_type", type=str, required=False)
18+
parser.add_argument("--service_type", type=str, required=False)
19+
parser.add_argument("--backend_type", type=str, required=False)
20+
parser.add_argument("--model_s3_bucket", type=str, required=False)
21+
parser.add_argument("--instance_type", type=str, required=False)
22+
parser.add_argument(
23+
"--extra_params",
24+
type=load_extra_params,
25+
required=False,
26+
default=os.environ.get("extra_params", "{}"),
27+
)
28+
29+
args = parser.parse_args()
30+
31+
service_params = args.extra_params.get("service_params", {})
32+
33+
if "vpc_id" in service_params:
34+
vpc_id = service_params.get("vpc_id")
35+
subnets = service_params.get("subnet_ids")
36+
else:
37+
vpc_id = None
38+
subnets = None
39+
40+
if "use_spot" in service_params and service_params.get("use_spot") == "true":
41+
use_spot = True
42+
else:
43+
use_spot = False
44+
45+
deploy_ecs_cluster(args.region, vpc_id, subnets, use_spot)
46+
47+
48+
if __name__ == "__main__":
49+
post_build()

src/emd/cfn/sagemaker_async/template.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@ Resources:
9393
EndpointConfigName: !GetAtt SageMakerEndpointConfig.EndpointConfigName
9494

9595
Outputs:
96-
ModelId:
97-
Description: The emd model ID to be used for the SageMaker Endpoint
98-
Value: !Ref ModelId
96+
Model:
97+
Description: Model ID used to generate the response
98+
Value: !Join ['', [!Ref ModelId, '/', !Ref ModelTag]]
9999
SageMakerEndpointName:
100100
Description: The name of the SageMaker Endpoint
101101
Value: !GetAtt SageMakerEndpoint.EndpointName
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import boto3
2+
import time
3+
import json
4+
import os
5+
import argparse
6+
import sys
7+
from emd.models.utils.serialize_utils import load_extra_params
8+
from emd.cfn.shared.ecs_cluster import deploy_ecs_cluster, remove_parameters_file
9+
# Post build script for SageMaker OpenAI Compatible Interface, it will deploy the VPC and ECS cluster with an API router Fargate ECS service.
10+
11+
12+
def post_build():
13+
parser = argparse.ArgumentParser()
14+
parser.add_argument("--region", type=str, required=False)
15+
parser.add_argument("--model_id", type=str, required=False)
16+
parser.add_argument("--model_tag", type=str, required=False)
17+
parser.add_argument("--framework_type", type=str, required=False)
18+
parser.add_argument("--service_type", type=str, required=False)
19+
parser.add_argument("--backend_type", type=str, required=False)
20+
parser.add_argument("--model_s3_bucket", type=str, required=False)
21+
parser.add_argument("--instance_type", type=str, required=False)
22+
parser.add_argument(
23+
"--extra_params",
24+
type=load_extra_params,
25+
required=False,
26+
default=os.environ.get("extra_params", "{}"),
27+
)
28+
29+
args = parser.parse_args()
30+
31+
service_params = args.extra_params.get("service_params", {})
32+
33+
if "vpc_id" in service_params:
34+
vpc_id = service_params.get("vpc_id")
35+
subnets = service_params.get("subnet_ids")
36+
else:
37+
vpc_id = None
38+
subnets = None
39+
40+
if "use_spot" in service_params and service_params.get("use_spot") == "true":
41+
use_spot = True
42+
else:
43+
use_spot = False
44+
45+
deploy_ecs_cluster(args.region, vpc_id, subnets, use_spot)
46+
47+
48+
if __name__ == "__main__":
49+
post_build()

src/emd/cfn/sagemaker_realtime/template.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ Parameters:
3737
SageMakerEndpointName:
3838
Type: String
3939
Description: The name of the SageMaker Endpoint
40-
Default: "noname"
40+
Default: "Auto-generate"
4141

4242
Conditions:
43-
UseDefaultEndpointName: !Equals [!Ref SageMakerEndpointName, "noname"]
43+
UseDefaultEndpointName: !Equals [!Ref SageMakerEndpointName, "Auto-generate"]
4444

4545
Resources:
4646
ExecutionRole:
@@ -148,9 +148,9 @@ Resources:
148148
ScaleOutCooldown: 600
149149

150150
Outputs:
151-
ModelId:
152-
Description: The emd model ID to be used for the SageMaker Endpoint
153-
Value: !Ref ModelId
151+
Model:
152+
Description: Model ID used to generate the response
153+
Value: !Join ['', [!Ref ModelId, '/', !Ref ModelTag]]
154154
SageMakerEndpointName:
155155
Description: The name of the SageMaker Endpoint
156156
Value: !GetAtt SageMakerEndpoint.EndpointName

0 commit comments

Comments
 (0)