Skip to content

Commit 9a09a1e

Browse files
authored
[Validate] Create custom evaluation functions (#305)
1 parent 0cc63f9 commit 9a09a1e

File tree

7 files changed

+100
-9
lines changed

7 files changed

+100
-9
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@ All notable changes to the [Nucleus Python Client](https://github.com/scaleapi/n
55
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

8+
## [0.12.0](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.12.0) - 2022-05-27
9+
10+
### Added
11+
12+
- Allow users to create placeholder evaluation functions for Scenario Tests in Validate
13+
14+
815
## [0.11.2](https://github.com/scaleapi/nucleus-python-client/releases/tag/v0.11.2) - 2022-05-20
916

1017
### Changed

nucleus/validate/client.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@
33
from nucleus.connection import Connection
44
from nucleus.job import AsyncJob
55

6-
from .constants import SCENARIO_TEST_ID_KEY
7-
from .data_transfer_objects.eval_function import GetEvalFunctions
6+
from .constants import EVAL_FUNCTION_KEY, SCENARIO_TEST_ID_KEY
7+
from .data_transfer_objects.eval_function import (
8+
CreateEvalFunction,
9+
EvalFunctionEntry,
10+
GetEvalFunctions,
11+
)
812
from .data_transfer_objects.scenario_test import (
913
CreateScenarioTestRequest,
1014
EvalFunctionListEntry,
@@ -175,3 +179,34 @@ def evaluate_model_on_scenario_tests(
175179
f"validate/{model_id}/evaluate",
176180
)
177181
return AsyncJob.from_json(response, self.connection)
182+
183+
def create_external_eval_function(
184+
self,
185+
name: str,
186+
) -> EvalFunctionEntry:
187+
"""Creates a new external evaluation function. This external function can be used to upload evaluation
188+
results with functions defined and computed by the customer, without having to share the source code of the
189+
respective function.
190+
191+
Args:
192+
name: unique name of evaluation function
193+
194+
Raises:
195+
- NucleusAPIError if the creation of the function fails on the server side
196+
- ValidationError if the evaluation name is not well defined
197+
198+
Returns:
199+
Created EvalFunctionConfig object.
200+
201+
"""
202+
203+
response = self.connection.post(
204+
CreateEvalFunction(
205+
name=name,
206+
is_external_function=True,
207+
serialized_fn=None,
208+
raw_source=None,
209+
).dict(),
210+
"validate/eval_fn",
211+
)
212+
return EvalFunctionEntry.parse_obj(response[EVAL_FUNCTION_KEY])

nucleus/validate/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from enum import Enum
22

3+
EVAL_FUNCTION_KEY = "eval_fn"
34
EVALUATION_ID_KEY = "evaluation_id"
45
EVAL_FUNCTION_ID_KEY = "eval_function_id"
56
ID_KEY = "id"

nucleus/validate/data_transfer_objects/eval_function.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ class EvalFunctionEntry(ImmutableModel):
7272
id: str
7373
name: str
7474
is_public: bool
75+
is_external_function: bool = False
7576
user_id: str
7677
serialized_fn: Optional[str] = None
7778
raw_source: Optional[str] = None
@@ -81,3 +82,24 @@ class GetEvalFunctions(ImmutableModel):
8182
"""Expected format from GET validate/eval_fn"""
8283

8384
eval_functions: List[EvalFunctionEntry]
85+
86+
87+
class CreateEvalFunction(ImmutableModel):
88+
"""Expected payload to POST validate/eval_fn"""
89+
90+
name: str
91+
is_external_function: bool
92+
serialized_fn: Optional[str] = None
93+
raw_source: Optional[str] = None
94+
95+
@validator("name")
96+
def name_is_valid(cls, v): # pylint: disable=no-self-argument
97+
if " " in v:
98+
raise ValueError(
99+
f"No spaces allowed in an evaluation function name, got '{v}'"
100+
)
101+
if len(v) == 0 or len(v) > 255:
102+
raise ValueError(
103+
"Name of evaluation function must be between 1-255 characters long"
104+
)
105+
return v

nucleus/validate/eval_functions/available_eval_functions.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,10 +1131,19 @@ class CustomEvalFunction(EvalFunctionConfig):
11311131
@classmethod
11321132
def expected_name(cls) -> str:
11331133
raise NotImplementedError(
1134-
"Custm evaluation functions are coming soon"
1134+
"Custom evaluation functions are coming soon"
11351135
) # Placeholder: See super().eval_func_entry for actual name
11361136

11371137

1138+
class ExternalEvalFunction(EvalFunctionConfig):
1139+
def __call__(self, **kwargs):
1140+
raise NotImplementedError("Cannot call an external function")
1141+
1142+
@classmethod
1143+
def expected_name(cls) -> str:
1144+
return "external_function"
1145+
1146+
11381147
class StandardEvalFunction(EvalFunctionConfig):
11391148
"""Class for standard Model CI eval functions that have not been added as attributes on
11401149
AvailableEvalFunctions yet.
@@ -1186,6 +1195,7 @@ def expected_name(cls) -> str:
11861195
CuboidPrecisionConfig,
11871196
CategorizationF1Config,
11881197
CustomEvalFunction,
1198+
ExternalEvalFunction,
11891199
EvalFunctionNotAvailable,
11901200
StandardEvalFunction,
11911201
PolygonMAPConfig,
@@ -1229,7 +1239,12 @@ def __init__(self, available_functions: List[EvalFunctionEntry]):
12291239
self._custom_to_function: Dict[str, CustomEvalFunction] = {
12301240
f.name: CustomEvalFunction(f)
12311241
for f in available_functions
1232-
if not f.is_public
1242+
if not f.is_public and not f.is_external_function
1243+
}
1244+
self._external_to_function: Dict[str, ExternalEvalFunction] = {
1245+
f.name: ExternalEvalFunction(f)
1246+
for f in available_functions
1247+
if f.is_external_function
12331248
}
12341249
self.bbox_iou: BoundingBoxIOUConfig = (
12351250
self._assign_eval_function_if_defined(BoundingBoxIOUConfig)
@@ -1294,8 +1309,9 @@ def __repr__(self):
12941309
str(name).lower() for name in self._public_func_entries.keys()
12951310
]
12961311
return (
1297-
f"<AvailableEvaluationFunctions: public:{functions_lower} "
1298-
f"private: {list(self._custom_to_function.keys())}"
1312+
f"<AvailableEvaluationFunctions: public: {functions_lower} "
1313+
f"private: {list(self._custom_to_function.keys())} "
1314+
f"external: {list(self._external_to_function.keys())} "
12991315
)
13001316

13011317
@property
@@ -1312,13 +1328,22 @@ def public_functions(self) -> Dict[str, EvalFunctionConfig]:
13121328

13131329
@property
13141330
def private_functions(self) -> Dict[str, CustomEvalFunction]:
1315-
"""Custom functions uploaded to Model CI
1331+
"""Private functions uploaded to Model CI
13161332
13171333
Returns:
13181334
Dict of function name to :class:`CustomEvalFunction`.
13191335
"""
13201336
return self._custom_to_function
13211337

1338+
@property
1339+
def external_functions(self) -> Dict[str, ExternalEvalFunction]:
1340+
"""External functions uploaded to Model CI
1341+
1342+
Returns:
1343+
Dict of function name to :class:`ExternalEvalFunction`.
1344+
"""
1345+
return self._external_to_function
1346+
13221347
def _assign_eval_function_if_defined(
13231348
self,
13241349
eval_function_constructor: Callable[[EvalFunctionEntry], EvalFunction],
@@ -1340,6 +1365,7 @@ def from_id(self, eval_function_id: str):
13401365
for eval_func in itertools.chain(
13411366
self._public_to_function.values(),
13421367
self._custom_to_function.values(),
1368+
self._external_to_function.values(),
13431369
):
13441370
if eval_func.id == eval_function_id:
13451371
return eval_func

nucleus/validate/scenario_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def add_eval_function(
9393
).dict(),
9494
"validate/scenario_test_eval_function",
9595
)
96-
print(response)
96+
9797
return ScenarioTestMetric(
9898
scenario_test_id=response[SCENARIO_TEST_ID_KEY],
9999
eval_function_id=response[EVAL_FUNCTION_ID_KEY],

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ exclude = '''
2121

2222
[tool.poetry]
2323
name = "scale-nucleus"
24-
version = "0.11.2"
24+
version = "0.12.0"
2525
description = "The official Python client library for Nucleus, the Data Platform for AI"
2626
license = "MIT"
2727
authors = ["Scale AI Nucleus Team <nucleusapi@scaleapi.com>"]

0 commit comments

Comments
 (0)