Skip to content

Commit 99ee5f1

Browse files
authored
[PLT-1463] Removed ND deserialize from some unit test part 2 (#1815)
1 parent 37c038e commit 99ee5f1

File tree

11 files changed: +1593 additions, -196 deletions (lines changed)

11 files changed: +1593 additions, -196 deletions (lines changed)
Lines changed: 149 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -1,38 +1,166 @@
11
import json
22

3+
from labelbox.data.annotation_types.data.generic_data_row_data import (
4+
GenericDataRowData,
5+
)
6+
from labelbox.data.annotation_types.metrics.confusion_matrix import (
7+
ConfusionMatrixMetric,
8+
)
39
from labelbox.data.serialization.ndjson.converter import NDJsonConverter
10+
from labelbox.types import (
11+
Label,
12+
ScalarMetric,
13+
ScalarMetricAggregation,
14+
ConfusionMatrixAggregation,
15+
)
416

517

618
def test_metric():
719
with open("tests/data/assets/ndjson/metric_import.json", "r") as file:
820
data = json.load(file)
921

10-
label_list = list(NDJsonConverter.deserialize(data))
11-
reserialized = list(NDJsonConverter.serialize(label_list))
12-
assert reserialized == data
22+
labels = [
23+
Label(
24+
data=GenericDataRowData(
25+
uid="ckrmdnqj4000007msh9p2a27r",
26+
),
27+
annotations=[
28+
ScalarMetric(
29+
value=0.1,
30+
extra={"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7672"},
31+
aggregation=ScalarMetricAggregation.ARITHMETIC_MEAN,
32+
)
33+
],
34+
)
35+
]
36+
37+
res = list(NDJsonConverter.serialize(labels))
38+
assert res == data
1339

1440

1541
def test_custom_scalar_metric():
16-
with open(
17-
"tests/data/assets/ndjson/custom_scalar_import.json", "r"
18-
) as file:
19-
data = json.load(file)
42+
data = [
43+
{
44+
"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7672",
45+
"dataRow": {"id": "ckrmdnqj4000007msh9p2a27r"},
46+
"metricValue": 0.1,
47+
"metricName": "custom_iou",
48+
"featureName": "sample_class",
49+
"subclassName": "sample_subclass",
50+
"aggregation": "SUM",
51+
},
52+
{
53+
"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7673",
54+
"dataRow": {"id": "ckrmdnqj4000007msh9p2a27r"},
55+
"metricValue": 0.1,
56+
"metricName": "custom_iou",
57+
"featureName": "sample_class",
58+
"aggregation": "SUM",
59+
},
60+
{
61+
"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7674",
62+
"dataRow": {"id": "ckrmdnqj4000007msh9p2a27r"},
63+
"metricValue": {0.1: 0.1, 0.2: 0.5},
64+
"metricName": "custom_iou",
65+
"aggregation": "SUM",
66+
},
67+
]
68+
69+
labels = [
70+
Label(
71+
data=GenericDataRowData(
72+
uid="ckrmdnqj4000007msh9p2a27r",
73+
),
74+
annotations=[
75+
ScalarMetric(
76+
value=0.1,
77+
feature_name="sample_class",
78+
subclass_name="sample_subclass",
79+
extra={"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7672"},
80+
metric_name="custom_iou",
81+
aggregation=ScalarMetricAggregation.SUM,
82+
),
83+
ScalarMetric(
84+
value=0.1,
85+
feature_name="sample_class",
86+
extra={"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7673"},
87+
metric_name="custom_iou",
88+
aggregation=ScalarMetricAggregation.SUM,
89+
),
90+
ScalarMetric(
91+
value={"0.1": 0.1, "0.2": 0.5},
92+
extra={"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7674"},
93+
metric_name="custom_iou",
94+
aggregation=ScalarMetricAggregation.SUM,
95+
),
96+
],
97+
)
98+
]
99+
100+
res = list(NDJsonConverter.serialize(labels))
20101

21-
label_list = list(NDJsonConverter.deserialize(data))
22-
reserialized = list(NDJsonConverter.serialize(label_list))
23-
assert json.dumps(reserialized, sort_keys=True) == json.dumps(
24-
data, sort_keys=True
25-
)
102+
assert res == data
26103

27104

28105
def test_custom_confusion_matrix_metric():
29-
with open(
30-
"tests/data/assets/ndjson/custom_confusion_matrix_import.json", "r"
31-
) as file:
32-
data = json.load(file)
106+
data = [
107+
{
108+
"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7672",
109+
"dataRow": {"id": "ckrmdnqj4000007msh9p2a27r"},
110+
"metricValue": (1, 1, 2, 3),
111+
"metricName": "50%_iou",
112+
"featureName": "sample_class",
113+
"subclassName": "sample_subclass",
114+
"aggregation": "CONFUSION_MATRIX",
115+
},
116+
{
117+
"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7673",
118+
"dataRow": {"id": "ckrmdnqj4000007msh9p2a27r"},
119+
"metricValue": (0, 1, 2, 5),
120+
"metricName": "50%_iou",
121+
"featureName": "sample_class",
122+
"aggregation": "CONFUSION_MATRIX",
123+
},
124+
{
125+
"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7674",
126+
"dataRow": {"id": "ckrmdnqj4000007msh9p2a27r"},
127+
"metricValue": {0.1: (0, 1, 2, 3), 0.2: (5, 3, 4, 3)},
128+
"metricName": "50%_iou",
129+
"aggregation": "CONFUSION_MATRIX",
130+
},
131+
]
132+
133+
labels = [
134+
Label(
135+
data=GenericDataRowData(
136+
uid="ckrmdnqj4000007msh9p2a27r",
137+
),
138+
annotations=[
139+
ConfusionMatrixMetric(
140+
value=(1, 1, 2, 3),
141+
feature_name="sample_class",
142+
subclass_name="sample_subclass",
143+
extra={"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7672"},
144+
metric_name="50%_iou",
145+
aggregation=ConfusionMatrixAggregation.CONFUSION_MATRIX,
146+
),
147+
ConfusionMatrixMetric(
148+
value=(0, 1, 2, 5),
149+
feature_name="sample_class",
150+
extra={"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7673"},
151+
metric_name="50%_iou",
152+
aggregation=ConfusionMatrixAggregation.CONFUSION_MATRIX,
153+
),
154+
ConfusionMatrixMetric(
155+
value={0.1: (0, 1, 2, 3), 0.2: (5, 3, 4, 3)},
156+
extra={"uuid": "a22bbf6e-b2da-4abe-9a11-df84759f7674"},
157+
metric_name="50%_iou",
158+
aggregation=ConfusionMatrixAggregation.CONFUSION_MATRIX,
159+
),
160+
],
161+
)
162+
]
163+
164+
res = list(NDJsonConverter.serialize(labels))
33165

34-
label_list = list(NDJsonConverter.deserialize(data))
35-
reserialized = list(NDJsonConverter.serialize(label_list))
36-
assert json.dumps(reserialized, sort_keys=True) == json.dumps(
37-
data, sort_keys=True
38-
)
166+
assert data == res
Lines changed: 109 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -1,32 +1,125 @@
11
import json
22

3+
from labelbox.data.annotation_types.data.generic_data_row_data import (
4+
GenericDataRowData,
5+
)
36
import pytest
47

58
from labelbox.data.serialization import NDJsonConverter
9+
from labelbox.types import (
10+
Label,
11+
MessageEvaluationTaskAnnotation,
12+
MessageSingleSelectionTask,
13+
MessageMultiSelectionTask,
14+
MessageInfo,
15+
OrderedMessageInfo,
16+
MessageRankingTask,
17+
)
618

719

820
def test_message_task_annotation_serialization():
921
with open("tests/data/assets/ndjson/mmc_import.json", "r") as file:
1022
data = json.load(file)
1123

12-
deserialized = list(NDJsonConverter.deserialize(data))
13-
reserialized = list(NDJsonConverter.serialize(deserialized))
24+
labels = [
25+
Label(
26+
data=GenericDataRowData(
27+
uid="cnjencjencjfencvj",
28+
),
29+
annotations=[
30+
MessageEvaluationTaskAnnotation(
31+
name="single-selection",
32+
extra={"uuid": "c1be3a57-597e-48cb-8d8d-a852665f9e72"},
33+
value=MessageSingleSelectionTask(
34+
message_id="clxfzocbm00083b6v8vczsept",
35+
model_config_name="GPT 5",
36+
parent_message_id="clxfznjb800073b6v43ppx9ca",
37+
),
38+
)
39+
],
40+
),
41+
Label(
42+
data=GenericDataRowData(
43+
uid="cfcerfvergerfefj",
44+
),
45+
annotations=[
46+
MessageEvaluationTaskAnnotation(
47+
name="multi-selection",
48+
extra={"uuid": "gferf3a57-597e-48cb-8d8d-a8526fefe72"},
49+
value=MessageMultiSelectionTask(
50+
parent_message_id="clxfznjb800073b6v43ppx9ca",
51+
selected_messages=[
52+
MessageInfo(
53+
message_id="clxfzocbm00083b6v8vczsept",
54+
model_config_name="GPT 5",
55+
)
56+
],
57+
),
58+
)
59+
],
60+
),
61+
Label(
62+
data=GenericDataRowData(
63+
uid="cwefgtrgrthveferfferffr",
64+
),
65+
annotations=[
66+
MessageEvaluationTaskAnnotation(
67+
name="ranking",
68+
extra={"uuid": "hybe3a57-5gt7e-48tgrb-8d8d-a852dswqde72"},
69+
value=MessageRankingTask(
70+
parent_message_id="clxfznjb800073b6v43ppx9ca",
71+
ranked_messages=[
72+
OrderedMessageInfo(
73+
message_id="clxfzocbm00083b6v8vczsept",
74+
model_config_name="GPT 4 with temperature 0.7",
75+
order=1,
76+
),
77+
OrderedMessageInfo(
78+
message_id="clxfzocbm00093b6vx4ndisub",
79+
model_config_name="GPT 5",
80+
order=2,
81+
),
82+
],
83+
),
84+
)
85+
],
86+
),
87+
]
1488

15-
assert data == reserialized
89+
res = list(NDJsonConverter.serialize(labels))
1690

91+
assert res == data
1792

18-
def test_mesage_ranking_task_wrong_order_serialization():
19-
with open("tests/data/assets/ndjson/mmc_import.json", "r") as file:
20-
data = json.load(file)
21-
22-
some_ranking_task = next(
23-
task
24-
for task in data
25-
if task["messageEvaluationTask"]["format"] == "message-ranking"
26-
)
27-
some_ranking_task["messageEvaluationTask"]["data"]["rankedMessages"][0][
28-
"order"
29-
] = 3
3093

94+
def test_mesage_ranking_task_wrong_order_serialization():
3195
with pytest.raises(ValueError):
32-
list(NDJsonConverter.deserialize([some_ranking_task]))
96+
(
97+
Label(
98+
data=GenericDataRowData(
99+
uid="cwefgtrgrthveferfferffr",
100+
),
101+
annotations=[
102+
MessageEvaluationTaskAnnotation(
103+
name="ranking",
104+
extra={
105+
"uuid": "hybe3a57-5gt7e-48tgrb-8d8d-a852dswqde72"
106+
},
107+
value=MessageRankingTask(
108+
parent_message_id="clxfznjb800073b6v43ppx9ca",
109+
ranked_messages=[
110+
OrderedMessageInfo(
111+
message_id="clxfzocbm00093b6vx4ndisub",
112+
model_config_name="GPT 5",
113+
order=1,
114+
),
115+
OrderedMessageInfo(
116+
message_id="clxfzocbm00083b6v8vczsept",
117+
model_config_name="GPT 4 with temperature 0.7",
118+
order=1,
119+
),
120+
],
121+
),
122+
)
123+
],
124+
),
125+
)

libs/labelbox/tests/data/serialization/ndjson/test_ndlabel_subclass_matching.py

Lines changed: 0 additions & 19 deletions
This file was deleted.

0 commit comments

Comments
 (0)