Skip to content

Commit b7210ab

Browse files
pushing correct version
1 parent 7b23992 commit b7210ab

File tree

2 files changed

+162
-27
lines changed

2 files changed

+162
-27
lines changed

src/sasctl/_services/score_definitions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414

1515
class ScoreDefinitions(Service):
16+
1617
"""
1718
Used for creating and maintaining score definitions.
1819

src/sasctl/_services/score_execution.py

Lines changed: 161 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
11
import json
2+
import time
3+
import warnings
4+
from distutils.version import StrictVersion
5+
from typing import Union
26

7+
import pandas as pd
38
from requests import HTTPError
49

10+
from .cas_management import CASManagement
11+
from ..core import current_session
512
from .score_definitions import ScoreDefinitions
613
from .service import Service
714

815

916
class ScoreExecution(Service):
17+
1018
"""
1119
The Score Execution API is used to produce a score by
1220
executing the mapped code generated by score objects using the score definition.
@@ -18,7 +26,9 @@ class ScoreExecution(Service):
1826
"""
1927

2028
_SERVICE_ROOT = "/scoreExecution"
29+
_cas_management = CASManagement()
2130
_score_definitions = ScoreDefinitions()
31+
_services = Service()
2232

2333
(
2434
list_executions,
@@ -57,39 +67,21 @@ def create_score_execution(
5767
5868
"""
5969

60-
# Gets information about the scoring object from the score definition and raises an exception if the score definition does not exist
70+
# Gets information about the scoring object from the score definition
6171
score_definition = cls._score_definitions.get_definition(score_definition_id)
6272
if not score_definition:
6373
raise HTTPError
6474
score_exec_name = score_definition.get("name")
65-
model_uri = score_definition.get("objectDescriptor", "uri")
66-
model_name = score_definition.get("objectDescriptor", "name")
67-
model_input_library = score_definition.get("inputData", "libraryName")
68-
model_table_name = score_definition.get("inputData", "tableName")
75+
model_uuid = score_definition.get("objectDescriptor").get("uri").split("/")[-1]
76+
model_uri = f"/modelManagement/models/{model_uuid}"
77+
model_name = score_definition.get("objectDescriptor").get("name")
78+
model_input_library = score_definition.get("inputData").get("libraryName")
79+
model_table_name = score_definition.get("inputData").get("tableName")
6980

7081
# Defining a default output table name if none is provided
7182
if not output_table_name:
7283
output_table_name = f"{model_name}_{score_definition_id}"
7384

74-
# Getting all score executions that are using the inputted score_definition_id
75-
76-
# score_execution = cls.list_executions(
77-
# filter=f"eq(scoreDefinitionId, '{score_definition_id}')"
78-
# )
79-
score_execution = cls.get("scoreExecution/executions",
80-
filter=f"filter=eq(scoreExecutionRequest.scoreDefinitionId,%{score_definition_id}%27)"
81-
)
82-
if not score_execution:
83-
raise HTTPError(f"Something went wrong in the LIST_EXECUTIONS statement.")
84-
85-
# Checking the count of the execution list to see if there are any score executions for this score_definition_id already running
86-
execution_count = score_execution.get("count") # Exception catch location
87-
if execution_count == 1:
88-
execution_id = score_execution.get("items", 0, "id")
89-
deleted_execution = cls.delete_execution(execution_id)
90-
if deleted_execution.status_code >= 400:
91-
raise HTTPError(f"Something went wrong in the DELETE statement.")
92-
9385
headers_score_exec = {"Content-Type": "application/json"}
9486

9587
create_score_exec = {
@@ -109,9 +101,151 @@ def create_score_execution(
109101
}
110102

111103
# Creating the score execution
112-
new_score_execution = cls.post(
113-
"scoreExecution/executions",
104+
score_execution = cls.post(
105+
"executions",
114106
data=json.dumps(create_score_exec),
115107
headers=headers_score_exec,
116108
)
117-
return new_score_execution
109+
110+
return score_execution
111+
112+
@classmethod
113+
def poll_score_execution_state(
114+
cls, score_execution: Union[dict, str], timeout: int = 300
115+
):
116+
"""Checks the state of the score execution.
117+
118+
Parameters
119+
--------
120+
score_execution: str or dict
121+
A running score_execution.
122+
timeout: int
123+
Time limit for checking the score_execution state.
124+
125+
Returns
126+
-------
127+
String
128+
129+
"""
130+
if type(score_execution) is str:
131+
exec_id = score_execution
132+
else:
133+
exec_id = score_execution.get("id")
134+
135+
start_poll = time.time()
136+
while time.time() - start_poll < timeout:
137+
score_execution_state = cls.get(f"executions/{exec_id}/state")
138+
if score_execution_state == "completed":
139+
print("Score execution state is 'completed'")
140+
return "completed"
141+
elif score_execution_state == "failed":
142+
# TODO: Grab score execution logs and return those
143+
print("The score execution state is failed.")
144+
return "failed"
145+
elif time.time() - start_poll > timeout:
146+
print("The score execution is still running, but polling time ran out.")
147+
return "timeout"
148+
149+
@classmethod
150+
def get_score_execution_results(
151+
cls,
152+
score_execution: Union[dict, str],
153+
):
154+
"""Generates an output table for the score_execution results.
155+
156+
Parameters
157+
--------
158+
score_execution: str or dict
159+
A running score_execution.
160+
161+
Returns
162+
-------
163+
Table reference
164+
165+
"""
166+
try:
167+
import swat
168+
except ImportError:
169+
swat = None
170+
171+
if type(score_execution) is str:
172+
score_execution = cls.get_execution(score_execution)
173+
174+
server_name = score_execution.get("outputTable").get("serverName")
175+
library_name = score_execution.get("outputTable").get("libraryName")
176+
table_name = score_execution.get("outputTable").get("tableName")
177+
178+
# If swat is not available, then
179+
if not swat:
180+
output_table = cls._no_gateway_get_results(
181+
server_name, library_name, table_name
182+
)
183+
return output_table
184+
else:
185+
session = current_session()
186+
cas = session.as_swat()
187+
response = cas.loadActionSet("gateway")
188+
if not response:
189+
output_table = cls._no_gateway_get_results(
190+
server_name, library_name, table_name
191+
)
192+
return output_table
193+
else:
194+
gateway_code = f"""
195+
import pandas as pd
196+
import numpy as np
197+
198+
table = gateway.read_table({{"caslib": "{library_name}", "name": "{table_name}"}})
199+
200+
gateway.return_table("Execution Results", df = table, label = "label", title = "title")"""
201+
202+
output_table = cas.gateway.runlang(
203+
code=gateway_code, single=True, timeout_millis=10000
204+
)
205+
output_table = pd.DataFrame(output_table["Execution Results"])
206+
return output_table
207+
208+
@classmethod
209+
def _no_gateway_get_results(cls, server_name, library_name, table_name):
210+
"""Helper method that builds the output table.
211+
212+
Parameters
213+
--------
214+
server_name: str
215+
CAS server where original table is stored.
216+
library_name: CAS library where original table is stored.
217+
table_name: Table that contains row and columns information to build the output table
218+
219+
Returns
220+
-------
221+
Pandas Dataframe
222+
223+
"""
224+
if pd.__version__ >= StrictVersion("1.0.3"):
225+
from pandas import json_normalize
226+
else:
227+
from pandas.io.json import json_normalize
228+
229+
warnings.warn(
230+
"Without swat installed, the amount of rows from the output table that "
231+
"can be collected are memory limited by the CAS worker."
232+
)
233+
234+
output_columns = cls._cas_management.get(
235+
f"servers/{server_name}/"
236+
f"caslibs/{library_name}/"
237+
f"tables/{table_name}/columns?limit=10000"
238+
)
239+
columns = json_normalize(output_columns.json(), "items")
240+
column_names = columns["names"].to_list()
241+
242+
output_rows = cls._services.get(
243+
f"casRowSets/servers/{server_name}"
244+
f"caslibs/{library_name}"
245+
f"tables/{table_name}/rows?limit=10000"
246+
)
247+
output_table = pd.DataFrame(
248+
json_normalize(output_rows.json()["items"])["cells"].to_list(),
249+
columns=column_names,
250+
)
251+
return output_table

0 commit comments

Comments
 (0)