Skip to content

Commit 59a2b50

Browse files
authored
Merge pull request #114 from scaleapi/vinjai/autocurate
Add scripts folder + test notebooks
2 parents bb0e985 + f682c94 commit 59a2b50

File tree

5 files changed

+546
-11
lines changed

5 files changed

+546
-11
lines changed

nucleus/autocurate.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,9 @@
88
from nucleus.job import AsyncJob
99

1010

11-
def entropy(name, model_runs, client):
12-
assert (
13-
len({model_run.dataset_id for model_run in model_runs}) == 1
14-
), f"Model runs have conflicting dataset ids: {model_runs}"
15-
model_run_ids = [model_run.model_run_id for model_run in model_runs]
16-
dataset_id = model_runs[0].dataset_id
11+
def entropy(name, model_run, client):
12+
model_run_ids = [model_run.model_run_id]
13+
dataset_id = model_run.dataset_id
1714
response = client.make_request(
1815
payload={"modelRunIds": model_run_ids},
1916
route=f"autocurate/{dataset_id}/single_model_entropy/{name}",

scripts/autocurate_bdd.ipynb

Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 6,
6+
"id": "d5f245bd-0ca9-4a1c-ba2c-bc952f9eb09f",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"import nucleus\n",
11+
"import nucleus.autocurate"
12+
]
13+
},
14+
{
15+
"cell_type": "code",
16+
"execution_count": 2,
17+
"id": "4c26e04c-2346-4be1-98a8-fe6dced7f8d1",
18+
"metadata": {},
19+
"outputs": [],
20+
"source": [
21+
"API_KEY = 'test_47f6394c4822426389461f36334a45ff' # Vinjai's API key\n",
22+
"client = nucleus.NucleusClient(API_KEY)"
23+
]
24+
},
25+
{
26+
"cell_type": "code",
27+
"execution_count": 115,
28+
"id": "fb4daca6-e217-485f-bdd3-ac9a60b48842",
29+
"metadata": {},
30+
"outputs": [],
31+
"source": [
32+
"run = client.get_model_run(model_run_id='run_c4s9crbjwjzg0b9g3200', dataset_id='ds_c4s5prvm7v0007rf4vag')"
33+
]
34+
},
35+
{
36+
"cell_type": "code",
37+
"execution_count": null,
38+
"id": "eee9ef2e-62cd-4da4-9862-d3903f47e00d",
39+
"metadata": {},
40+
"outputs": [],
41+
"source": []
42+
},
43+
{
44+
"cell_type": "code",
45+
"execution_count": 116,
46+
"id": "904c4570-bfda-4bb3-9c46-eb44ff62d5f3",
47+
"metadata": {},
48+
"outputs": [],
49+
"source": [
50+
"job = nucleus.autocurate.entropy(\"Mean Entropy (with Nones)\", [run], client)"
51+
]
52+
},
53+
{
54+
"cell_type": "code",
55+
"execution_count": 117,
56+
"id": "d5ff24e8-d960-4ddc-a2ea-6cd14809de4c",
57+
"metadata": {},
58+
"outputs": [
59+
{
60+
"data": {
61+
"text/plain": [
62+
"{'job_id': 'job_c4s9n7wwm9105t7g22bg',\n",
63+
" 'status': 'Completed',\n",
64+
" 'message': {'status_log': 'No additional information can be provided at this time.'},\n",
65+
" 'job_progress': '0.00',\n",
66+
" 'completed_steps': 0,\n",
67+
" 'total_steps': 0}"
68+
]
69+
},
70+
"execution_count": 117,
71+
"metadata": {},
72+
"output_type": "execute_result"
73+
}
74+
],
75+
"source": [
76+
"job.status()"
77+
]
78+
},
79+
{
80+
"cell_type": "code",
81+
"execution_count": null,
82+
"id": "76bf3af8-6fa3-4bcd-abf6-2184da9ad1b2",
83+
"metadata": {},
84+
"outputs": [],
85+
"source": []
86+
},
87+
{
88+
"cell_type": "code",
89+
"execution_count": null,
90+
"id": "87504db0-4fe2-4588-a929-2f45957dfaf4",
91+
"metadata": {},
92+
"outputs": [],
93+
"source": []
94+
},
95+
{
96+
"cell_type": "code",
97+
"execution_count": 93,
98+
"id": "99894c0b-ee0f-4721-84b6-8c6881a52265",
99+
"metadata": {},
100+
"outputs": [],
101+
"source": [
102+
"dataset = client.get_dataset('ds_c4s4cx39h33g09hm2hh0')\n",
103+
"model = client.get_model('prj_c4s4ent9m78g099966j0')"
104+
]
105+
},
106+
{
107+
"cell_type": "code",
108+
"execution_count": 94,
109+
"id": "6b6ef647-9c2e-4195-a0ed-56ba79518762",
110+
"metadata": {},
111+
"outputs": [
112+
{
113+
"name": "stderr",
114+
"output_type": "stream",
115+
"text": [
116+
"0it [00:00, ?it/s]\n"
117+
]
118+
}
119+
],
120+
"source": [
121+
"run = model.create_run(name='test', dataset=dataset, predictions=[])"
122+
]
123+
},
124+
{
125+
"cell_type": "code",
126+
"execution_count": 95,
127+
"id": "a3b41138-5bce-46af-82fe-b8c47fd57290",
128+
"metadata": {},
129+
"outputs": [],
130+
"source": [
131+
"pred = [nucleus.BoxPrediction(label='traffic sign', x=690, y=337, width=96, height=40, reference_id='34f3e55f-9919bb47.jpg', class_pdf={'person': 0.00022668440182883353,\n",
132+
" 'rider': 0.00012657202022136208,\n",
133+
" 'car': 0.001457674309916847,\n",
134+
" 'truck': 0.004805948302644584,\n",
135+
" 'bus': 0.015461443014798156,\n",
136+
" 'train': 0.0003142998148843214,\n",
137+
" 'motor': 0.00019848517426185146,\n",
138+
" 'bike': 0.0002754959884690579,\n",
139+
" 'traffic light': 0.0006128712862352474,\n",
140+
" 'traffic sign': 0.9765205256867397}, confidence=0.9765205256867397)]"
141+
]
142+
},
143+
{
144+
"cell_type": "code",
145+
"execution_count": 97,
146+
"id": "8602b078-803d-40a7-bafd-41067fa94d07",
147+
"metadata": {},
148+
"outputs": [
149+
{
150+
"name": "stderr",
151+
"output_type": "stream",
152+
"text": [
153+
"100%|██████████| 1/1 [00:04<00:00, 4.73s/it]\n"
154+
]
155+
},
156+
{
157+
"data": {
158+
"text/plain": [
159+
"{'model_run_id': 'run_c4s4ktqwm91bgxksm0e0',\n",
160+
" 'predictions_processed': 1,\n",
161+
" 'predictions_ignored': 0}"
162+
]
163+
},
164+
"execution_count": 97,
165+
"metadata": {},
166+
"output_type": "execute_result"
167+
}
168+
],
169+
"source": [
170+
"run.predict(pred)"
171+
]
172+
},
173+
{
174+
"cell_type": "code",
175+
"execution_count": 100,
176+
"id": "da741f8d-36d1-4fa7-84ce-5c56428d652c",
177+
"metadata": {},
178+
"outputs": [
179+
{
180+
"data": {
181+
"text/plain": [
182+
"{'person': 0.00022668440182883353,\n",
183+
" 'rider': 0.00012657202022136208,\n",
184+
" 'car': 0.001457674309916847,\n",
185+
" 'truck': 0.004805948302644584,\n",
186+
" 'bus': 0.015461443014798156,\n",
187+
" 'train': 0.0003142998148843214,\n",
188+
" 'motor': 0.00019848517426185146,\n",
189+
" 'bike': 0.0002754959884690579,\n",
190+
" 'traffic light': 0.0006128712862352474,\n",
191+
" 'traffic sign': 0.9765205256867397}"
192+
]
193+
},
194+
"execution_count": 100,
195+
"metadata": {},
196+
"output_type": "execute_result"
197+
}
198+
],
199+
"source": [
200+
"run.ungrouped_export()['box'][0].class_pdf"
201+
]
202+
},
203+
{
204+
"cell_type": "code",
205+
"execution_count": null,
206+
"id": "1c51f65b-39e9-4814-b36c-21e9d319d610",
207+
"metadata": {},
208+
"outputs": [],
209+
"source": []
210+
}
211+
],
212+
"metadata": {
213+
"kernelspec": {
214+
"display_name": "Python 3",
215+
"language": "python",
216+
"name": "python3"
217+
},
218+
"language_info": {
219+
"codemirror_mode": {
220+
"name": "ipython",
221+
"version": 3
222+
},
223+
"file_extension": ".py",
224+
"mimetype": "text/x-python",
225+
"name": "python",
226+
"nbconvert_exporter": "python",
227+
"pygments_lexer": "ipython3",
228+
"version": "3.8.10"
229+
}
230+
},
231+
"nbformat": 4,
232+
"nbformat_minor": 5
233+
}

0 commit comments

Comments
 (0)