Skip to content

Commit 81bb59d

Browse files
committed
Updates to score testing logic
1 parent 472691b commit 81bb59d

File tree

3 files changed

+18
-10
lines changed

3 files changed

+18
-10
lines changed

src/sasctl/_services/score_definitions.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,15 +83,20 @@ def create_score_definition(
8383
else:
8484
object_descriptor_type = "sas.models.model.ds2"
8585

86-
model = cls._model_repository.get_model(model)
86+
if cls._model_repository.is_uuid(model):
87+
model_id = model
88+
elif isinstance(model, dict) and "id" in model:
89+
model_id = model["id"]
90+
else:
91+
model = cls._model_repository.get_model(model)
92+
model_id = model["id"]
8793

8894
if not model:
8995
raise HTTPError(
9096
{
9197
f"This model may not exist in a project or the model may not exist at all."
9298
}
9399
)
94-
model_id = model.id
95100
model_project_id = model.get("projectId")
96101
model_project_version_id = model.get("projectVersionId")
97102
model_name = model.get("name")

src/sasctl/_services/score_execution.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def poll_score_execution_state(
149149
def get_score_execution_results(
150150
cls,
151151
score_execution: Union[dict, str],
152+
use_cas_gateway: False
152153
):
153154
"""Generates an output table for the score_execution results.
154155
@@ -183,13 +184,13 @@ def get_score_execution_results(
183184
else:
184185
session = current_session()
185186
cas = session.as_swat()
186-
response = cas.loadActionSet("gateway")
187-
if not response:
187+
if not use_cas_gateway:
188188
output_table = cls._no_gateway_get_results(
189189
server_name, library_name, table_name
190190
)
191191
return output_table
192192
else:
193+
cas.loadActionSet("gateway")
193194
gateway_code = f"""
194195
import pandas as pd
195196
import numpy as np
@@ -235,12 +236,14 @@ def _no_gateway_get_results(cls, server_name, library_name, table_name):
235236
f"caslibs/{library_name}/"
236237
f"tables/{table_name}/columns?limit=10000"
237238
)
238-
columns = json_normalize(output_columns.json(), "items")
239-
column_names = columns["names"].to_list()
239+
columns = json_normalize(output_columns)
240+
column_names = columns["name"].to_list()
240241

241-
output_rows = cls._services.get(
242-
f"casRowSets/servers/{server_name}"
243-
f"caslibs/{library_name}"
242+
session = current_session()
243+
244+
output_rows = session.get(
245+
f"casRowSets/servers/{server_name}/"
246+
f"caslibs/{library_name}/"
244247
f"tables/{table_name}/rows?limit=10000"
245248
)
246249
output_table = pd.DataFrame(

src/sasctl/tasks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -996,5 +996,5 @@ def score_model_with_cas(
996996
score_execution = se.create_score_execution(score_definition.id)
997997
score_execution_poll = se.poll_score_execution_state(score_execution)
998998
print(score_execution_poll)
999-
score_results = se.get_score_execution_results(score_execution)
999+
score_results = se.get_score_execution_results(score_execution, use_cas_gateway)
10001000
return score_results

0 commit comments

Comments
 (0)