Skip to content

Commit e3f99dd

Browse files
author
Vinjai Vale
committed
Add test scripts
1 parent d3ee877 commit e3f99dd

File tree

3 files changed

+536
-0
lines changed

3 files changed

+536
-0
lines changed

nucleus/autocurate.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ def entropy(name, model_runs, client):
1212
assert (
1313
len({model_run.dataset_id for model_run in model_runs}) == 1
1414
), f"Model runs have conflicting dataset ids: {model_runs}"
15+
# TODO: support multiple model runs
16+
assert (
17+
len(model_runs) == 1
18+
), "Entropy currently not supported for multiple model runs"
1519
model_run_ids = [model_run.model_run_id for model_run in model_runs]
1620
dataset_id = model_runs[0].dataset_id
1721
response = client.make_request(

scripts/autocurate_bdd.ipynb

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 6,
6+
"id": "d5f245bd-0ca9-4a1c-ba2c-bc952f9eb09f",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"import nucleus\n",
11+
"import nucleus.autocurate"
12+
]
13+
},
14+
{
15+
"cell_type": "code",
16+
"execution_count": 2,
17+
"id": "4c26e04c-2346-4be1-98a8-fe6dced7f8d1",
18+
"metadata": {},
19+
"outputs": [],
20+
"source": [
21+
"API_KEY = 'test_47f6394c4822426389461f36334a45ff' # Vinjai's API key\n",
22+
"client = nucleus.NucleusClient(API_KEY)"
23+
]
24+
},
25+
{
26+
"cell_type": "code",
27+
"execution_count": 111,
28+
"id": "fb4daca6-e217-485f-bdd3-ac9a60b48842",
29+
"metadata": {},
30+
"outputs": [],
31+
"source": [
32+
"run = client.get_model_run(model_run_id='run_c4s5wq257tf00d9m3wsg', dataset_id='ds_c4s5prvm7v0007rf4vag')"
33+
]
34+
},
35+
{
36+
"cell_type": "code",
37+
"execution_count": 112,
38+
"id": "904c4570-bfda-4bb3-9c46-eb44ff62d5f3",
39+
"metadata": {},
40+
"outputs": [],
41+
"source": [
42+
"job = nucleus.autocurate.entropy(\"Mean Entropy\", [run], client)"
43+
]
44+
},
45+
{
46+
"cell_type": "code",
47+
"execution_count": 114,
48+
"id": "d5ff24e8-d960-4ddc-a2ea-6cd14809de4c",
49+
"metadata": {},
50+
"outputs": [
51+
{
52+
"data": {
53+
"text/plain": [
54+
"{'job_id': 'job_c4s61m6wm91bgxksmggg',\n",
55+
" 'status': 'Completed',\n",
56+
" 'message': {'status_log': 'No additional information can be provided at this time.'},\n",
57+
" 'job_progress': '0.00',\n",
58+
" 'completed_steps': 0,\n",
59+
" 'total_steps': 0}"
60+
]
61+
},
62+
"execution_count": 114,
63+
"metadata": {},
64+
"output_type": "execute_result"
65+
}
66+
],
67+
"source": [
68+
"job.status()"
69+
]
70+
},
71+
{
72+
"cell_type": "code",
73+
"execution_count": null,
74+
"id": "76bf3af8-6fa3-4bcd-abf6-2184da9ad1b2",
75+
"metadata": {},
76+
"outputs": [],
77+
"source": []
78+
},
79+
{
80+
"cell_type": "code",
81+
"execution_count": null,
82+
"id": "87504db0-4fe2-4588-a929-2f45957dfaf4",
83+
"metadata": {},
84+
"outputs": [],
85+
"source": []
86+
},
87+
{
88+
"cell_type": "code",
89+
"execution_count": 93,
90+
"id": "99894c0b-ee0f-4721-84b6-8c6881a52265",
91+
"metadata": {},
92+
"outputs": [],
93+
"source": [
94+
"dataset = client.get_dataset('ds_c4s4cx39h33g09hm2hh0')\n",
95+
"model = client.get_model('prj_c4s4ent9m78g099966j0')"
96+
]
97+
},
98+
{
99+
"cell_type": "code",
100+
"execution_count": 94,
101+
"id": "6b6ef647-9c2e-4195-a0ed-56ba79518762",
102+
"metadata": {},
103+
"outputs": [
104+
{
105+
"name": "stderr",
106+
"output_type": "stream",
107+
"text": [
108+
"0it [00:00, ?it/s]\n"
109+
]
110+
}
111+
],
112+
"source": [
113+
"run = model.create_run(name='test', dataset=dataset, predictions=[])"
114+
]
115+
},
116+
{
117+
"cell_type": "code",
118+
"execution_count": 95,
119+
"id": "a3b41138-5bce-46af-82fe-b8c47fd57290",
120+
"metadata": {},
121+
"outputs": [],
122+
"source": [
123+
"pred = [nucleus.BoxPrediction(label='traffic sign', x=690, y=337, width=96, height=40, reference_id='34f3e55f-9919bb47.jpg', class_pdf={'person': 0.00022668440182883353,\n",
124+
" 'rider': 0.00012657202022136208,\n",
125+
" 'car': 0.001457674309916847,\n",
126+
" 'truck': 0.004805948302644584,\n",
127+
" 'bus': 0.015461443014798156,\n",
128+
" 'train': 0.0003142998148843214,\n",
129+
" 'motor': 0.00019848517426185146,\n",
130+
" 'bike': 0.0002754959884690579,\n",
131+
" 'traffic light': 0.0006128712862352474,\n",
132+
" 'traffic sign': 0.9765205256867397}, confidence=0.9765205256867397)]"
133+
]
134+
},
135+
{
136+
"cell_type": "code",
137+
"execution_count": 97,
138+
"id": "8602b078-803d-40a7-bafd-41067fa94d07",
139+
"metadata": {},
140+
"outputs": [
141+
{
142+
"name": "stderr",
143+
"output_type": "stream",
144+
"text": [
145+
"100%|██████████| 1/1 [00:04<00:00, 4.73s/it]\n"
146+
]
147+
},
148+
{
149+
"data": {
150+
"text/plain": [
151+
"{'model_run_id': 'run_c4s4ktqwm91bgxksm0e0',\n",
152+
" 'predictions_processed': 1,\n",
153+
" 'predictions_ignored': 0}"
154+
]
155+
},
156+
"execution_count": 97,
157+
"metadata": {},
158+
"output_type": "execute_result"
159+
}
160+
],
161+
"source": [
162+
"run.predict(pred)"
163+
]
164+
},
165+
{
166+
"cell_type": "code",
167+
"execution_count": 100,
168+
"id": "da741f8d-36d1-4fa7-84ce-5c56428d652c",
169+
"metadata": {},
170+
"outputs": [
171+
{
172+
"data": {
173+
"text/plain": [
174+
"{'person': 0.00022668440182883353,\n",
175+
" 'rider': 0.00012657202022136208,\n",
176+
" 'car': 0.001457674309916847,\n",
177+
" 'truck': 0.004805948302644584,\n",
178+
" 'bus': 0.015461443014798156,\n",
179+
" 'train': 0.0003142998148843214,\n",
180+
" 'motor': 0.00019848517426185146,\n",
181+
" 'bike': 0.0002754959884690579,\n",
182+
" 'traffic light': 0.0006128712862352474,\n",
183+
" 'traffic sign': 0.9765205256867397}"
184+
]
185+
},
186+
"execution_count": 100,
187+
"metadata": {},
188+
"output_type": "execute_result"
189+
}
190+
],
191+
"source": [
192+
"run.ungrouped_export()['box'][0].class_pdf"
193+
]
194+
},
195+
{
196+
"cell_type": "code",
197+
"execution_count": null,
198+
"id": "1c51f65b-39e9-4814-b36c-21e9d319d610",
199+
"metadata": {},
200+
"outputs": [],
201+
"source": []
202+
}
203+
],
204+
"metadata": {
205+
"kernelspec": {
206+
"display_name": "Python 3",
207+
"language": "python",
208+
"name": "python3"
209+
},
210+
"language_info": {
211+
"codemirror_mode": {
212+
"name": "ipython",
213+
"version": 3
214+
},
215+
"file_extension": ".py",
216+
"mimetype": "text/x-python",
217+
"name": "python",
218+
"nbconvert_exporter": "python",
219+
"pygments_lexer": "ipython3",
220+
"version": "3.8.10"
221+
}
222+
},
223+
"nbformat": 4,
224+
"nbformat_minor": 5
225+
}

0 commit comments

Comments
 (0)