Skip to content

Commit 5a295f9

Browse files
committed
updated arity checking for both load and apply functions; if a load function does not have a parameter; don't try to pass it one.
1 parent ada7f9c commit 5a295f9

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

adk/ADK.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,19 @@ def __init__(self, apply_func, load_func=None, client=None, manifest_path="model
1111
"""
1212
Creates the adk object
1313
:param apply_func: A required function that can have an arity of 1-2, depending on if loading occurs
14-
:param load_func: An optional supplier function used if load time events are required, has an arity of 0.
15-
:param client: A Algorithmia Client instance that might be user defined, and is used for interacting with a model manifest file; if defined.
14+
:param load_func: An optional supplier function used if load time events are required, if a model manifest is provided;
15+
the function may have a single `manifest` parameter to interact with the model manifest, otherwise must have no parameters.
16+
:param client: A Algorithmia Client instance that might be user defined,
17+
and is used for interacting with a model manifest file; if defined.
1618
:param manifest_path: A development / testing facing variable used to set the name and path
1719
"""
1820
self.FIFO_PATH = "/tmp/algoout"
1921
apply_args, _, _, _, _, _, _ = inspect.getfullargspec(apply_func)
22+
self.apply_arity = len(apply_args)
2023
if load_func:
2124
load_args, _, _, _, _, _, _ = inspect.getfullargspec(load_func)
22-
if len(load_args) > 2:
25+
self.load_arity = len(load_args)
26+
if self.load_arity > 2:
2327
raise Exception("load function may either have no parameters, or one parameter providing the manifest "
2428
"state.")
2529
self.load_func = load_func
@@ -35,9 +39,10 @@ def __init__(self, apply_func, load_func=None, client=None, manifest_path="model
3539

3640
def load(self):
3741
try:
38-
if self.load_func and self.manifest.available():
42+
if self.manifest.available():
3943
self.manifest.initialize()
40-
self.load_result = self.load_func(self.manifest)
44+
if self.load_func and self.load_arity == 1:
45+
self.load_result = self.load_func(self.manifest)
4146
elif self.load_func:
4247
self.load_result = self.load_func()
4348
except Exception as e:
@@ -51,7 +56,7 @@ def load(self):
5156

5257
def apply(self, payload):
5358
try:
54-
if self.load_result:
59+
if self.load_result and self.apply_arity == 2:
5560
apply_result = self.apply_func(payload, self.load_result)
5661
else:
5762
apply_result = self.apply_func(payload)

0 commit comments

Comments
 (0)