Skip to content

Commit 58b714d

Browse files
authored
[PLT-312] add support for embeddings (#1534)
1 parent f27928e commit 58b714d

File tree

7 files changed

+617
-17
lines changed

7 files changed

+617
-17
lines changed
Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
{
2+
"nbformat": 4,
3+
"nbformat_minor": 0,
4+
"metadata": {},
5+
"cells": [
6+
{
7+
"metadata": {},
8+
"source": [
9+
"<td>\n",
10+
" <a target=\"_blank\" href=\"https://labelbox.com\" ><img src=\"https://labelbox.com/blog/content/images/2021/02/logo-v4.svg\" width=256/></a>\n",
11+
"</td>"
12+
],
13+
"cell_type": "markdown"
14+
},
15+
{
16+
"metadata": {},
17+
"source": [
18+
"<td>\n",
19+
"<a href=\"https://colab.research.google.com/github/Labelbox/labelbox-python/blob/master/examples/basics/custom_embeddings.ipynb\" target=\"_blank\"><img\n",
20+
"src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"></a>\n",
21+
"</td>\n",
22+
"\n",
23+
"<td>\n",
24+
"<a href=\"https://github.com/Labelbox/labelbox-python/tree/master/examples/basics/custom_embeddings.ipynb\" target=\"_blank\"><img\n",
25+
"src=\"https://img.shields.io/badge/GitHub-100000?logo=github&logoColor=white\" alt=\"GitHub\"></a>\n",
26+
"</td>"
27+
],
28+
"cell_type": "markdown"
29+
},
30+
{
31+
"cell_type": "markdown",
32+
"source": [
33+
"# Custom Embeddings\n",
34+
"\n",
35+
"You can improve your data exploration and similarity search experience by adding your own custom embeddings. Labelbox allows you to upload up to 100 different custom embeddings on any kind of data. You can experiment with different embeddings to power your data selection."
36+
],
37+
"metadata": {
38+
"collapsed": false
39+
}
40+
},
41+
{
42+
"metadata": {},
43+
"source": [
44+
"# Setup"
45+
],
46+
"cell_type": "markdown"
47+
},
48+
{
49+
"metadata": {},
50+
"source": [
51+
"!pip3 install -q \"labelbox\""
52+
],
53+
"cell_type": "code",
54+
"outputs": [],
55+
"execution_count": null
56+
},
57+
{
58+
"metadata": {},
59+
"source": [
60+
"import labelbox as lb\n",
61+
"import numpy as np\n",
62+
"import json"
63+
],
64+
"cell_type": "code",
65+
"outputs": [],
66+
"execution_count": null
67+
},
68+
{
69+
"metadata": {},
70+
"source": [
71+
"API_KEY = \"\"\n",
72+
"client = lb.Client(API_KEY)"
73+
],
74+
"cell_type": "code",
75+
"outputs": [],
76+
"execution_count": null
77+
},
78+
{
79+
"metadata": {},
80+
"source": [
81+
"# Select data rows in Labelbox for custom embeddings"
82+
],
83+
"cell_type": "markdown"
84+
},
85+
{
86+
"metadata": {},
87+
"source": [
88+
"client.enable_experimental = True\n",
89+
"\n",
90+
"# get images from a Labelbox dataset\n",
91+
"# Our systems start to process data after 1000 embeddings of each type, so for this demo make sure your dataset has more than 1000 data rows\n",
92+
"dataset = client.get_dataset(\"<ADD YOUR DATASET ID>\")\n",
93+
"\n",
94+
"export_task = dataset.export()\n",
95+
"export_task.wait_till_done()"
96+
],
97+
"cell_type": "code",
98+
"outputs": [],
99+
"execution_count": null
100+
},
101+
{
102+
"metadata": {},
103+
"source": [
104+
"data_rows = []\n",
105+
"\n",
106+
"def json_stream_handler(output: lb.JsonConverterOutput):\n",
107+
" data_row = json.loads(output.json_str)\n",
108+
" data_rows.append(data_row)\n",
109+
"\n",
110+
"if export_task.has_errors():\n",
111+
" export_task.get_stream(\n",
112+
" converter=lb.JsonConverter(),\n",
113+
" stream_type=lb.StreamType.ERRORS\n",
114+
" ).start(stream_handler=lambda error: print(error))\n",
115+
"\n",
116+
"if export_task.has_result():\n",
117+
" export_json = export_task.get_stream(\n",
118+
" converter=lb.JsonConverter(),\n",
119+
" stream_type=lb.StreamType.RESULT\n",
120+
" ).start(stream_handler=json_stream_handler)"
121+
],
122+
"cell_type": "code",
123+
"outputs": [],
124+
"execution_count": null
125+
},
126+
{
127+
"metadata": {},
128+
"source": [
129+
"data_row_ids = [dr[\"data_row\"][\"id\"] for dr in data_rows]\n",
130+
"\n",
131+
"data_row_ids = data_row_ids[:1000] # keep the first 1000 examples for the sake of this demo"
132+
],
133+
"cell_type": "code",
134+
"outputs": [],
135+
"execution_count": null
136+
},
137+
{
138+
"metadata": {},
139+
"source": [
140+
"# Create the payload for custom embeddings\n",
141+
"- It should be a .ndjson file.\n",
142+
"- Every line is a JSON object that ends with a \\n character.\n",
143+
"- It does not have to be created through Python."
144+
],
145+
"cell_type": "markdown"
146+
},
147+
{
148+
"metadata": {},
149+
"source": [
150+
"nb_data_rows = len(data_row_ids)\n",
151+
"print(\"Number of data rows: \", nb_data_rows)\n",
152+
"# Generate random vectors, of dimension 2048 each\n",
153+
"# Labelbox supports custom embedding vectors of dimension up to 2048\n",
154+
"custom_embeddings = [list(np.random.random(2048)) for _ in range(nb_data_rows)]"
155+
],
156+
"cell_type": "code",
157+
"outputs": [],
158+
"execution_count": null
159+
},
160+
{
161+
"metadata": {},
162+
"source": [
163+
"# Create the payload for custom embeddings\n",
164+
"payload = []\n",
165+
"for data_row_id,custom_embedding in zip(data_row_ids,custom_embeddings):\n",
166+
" payload.append({\"id\": data_row_id, \"vector\": custom_embedding})\n",
167+
"\n",
168+
"print('payload', len(payload),payload[:1])"
169+
],
170+
"cell_type": "code",
171+
"outputs": [],
172+
"execution_count": null
173+
},
174+
{
175+
"metadata": {},
176+
"source": [
177+
"# Delete any pre-existing file\n",
178+
"import os\n",
179+
"if os.path.exists(\"payload.ndjson\"):\n",
180+
" os.remove(\"payload.ndjson\")\n",
181+
"\n",
182+
"# Write the payload to an NDJSON file, one JSON object per line\n",
183+
"with open('payload.ndjson', 'w') as f:\n",
184+
" for p in payload:\n",
185+
" f.write(json.dumps(p) + \"\\n\")\n",
186+
" # sanity_check_payload = json.dump(payload, f)"
187+
],
188+
"cell_type": "code",
189+
"outputs": [],
190+
"execution_count": null
191+
},
192+
{
193+
"metadata": {},
194+
"source": [
195+
"# Sanity check that you can read/load the file and the payload is correct\n",
196+
"with open('payload.ndjson') as f:\n",
197+
" sanity_check_payload = [json.loads(l) for l in f.readlines()]\n",
198+
"print(\"Nb of custom embedding vectors in sanity_check_payload: \", len(sanity_check_payload))"
199+
],
200+
"cell_type": "code",
201+
"outputs": [],
202+
"execution_count": null
203+
},
204+
{
205+
"metadata": {},
206+
"source": [
207+
"# See all custom embeddings available in your Labelbox workspace\n",
208+
"embeddings = client.get_embeddings()"
209+
],
210+
"cell_type": "code",
211+
"outputs": [],
212+
"execution_count": null
213+
},
214+
{
215+
"metadata": {},
216+
"source": [
217+
"# Create a new custom embedding, unless you want to re-use one\n",
218+
"embedding = client.create_embedding(\"my_custom_embedding_2048_dimensions\", 2048)"
219+
],
220+
"cell_type": "code",
221+
"outputs": [],
222+
"execution_count": null
223+
},
224+
{
225+
"metadata": {},
226+
"source": [
227+
"# Delete a custom embedding (skip this cell if you want to upload vectors to this embedding in the next step)\n",
228+
"embedding.delete()"
229+
],
230+
"cell_type": "code",
231+
"outputs": [],
232+
"execution_count": null
233+
},
234+
{
235+
"metadata": {},
236+
"source": [
237+
"# Upload the payload to Labelbox"
238+
],
239+
"cell_type": "markdown"
240+
},
241+
{
242+
"metadata": {},
243+
"source": [
244+
"# Import the vectors from the payload file into the custom embedding created above (or into another existing custom embedding)\n",
245+
"embedding.import_vectors_from_file(\"./payload.ndjson\")"
246+
],
247+
"cell_type": "code",
248+
"outputs": [],
249+
"execution_count": null
250+
},
251+
{
252+
"metadata": {},
253+
"source": [
254+
"# Get the count of imported vectors for a custom embedding"
255+
],
256+
"cell_type": "markdown"
257+
},
258+
{
259+
"metadata": {},
260+
"source": [
261+
"# Count how many data rows have a specific custom embedding (this can take a couple of minutes)\n",
262+
"count = embedding.get_imported_vector_count()"
263+
],
264+
"cell_type": "code",
265+
"outputs": [],
266+
"execution_count": null
267+
}
268+
]
269+
}
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import io
2+
import json
3+
import logging
4+
from typing import Dict, Any, Optional, List, Callable
5+
from urllib.parse import urlparse
6+
7+
import requests
8+
from requests import Session, Response
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
class AdvClient:
    """HTTP client for the custom-embedding endpoints under ``/adv/v1/embeddings``.

    Every request goes through a single ``requests`` session that carries a
    bearer ``Authorization`` header built from ``api_key`` and a default
    ``Content-Type: application/json`` header.
    """

    def __init__(self, endpoint: str, api_key: str):
        """
        Args:
            endpoint: Base URL that every request path is appended to.
            api_key: Key sent as the bearer token on every request.
        """
        self.endpoint = endpoint
        self.api_key = api_key
        self.session = self._create_session()

    def create_embedding(self, name: str, dims: int) -> Dict[str, Any]:
        """POST a new embedding definition and return the decoded JSON response."""
        data = {"name": name, "dims": dims}
        return self._request("POST", "/adv/v1/embeddings", data)

    def delete_embedding(self, id: str):
        """DELETE the embedding with the given id and return the decoded JSON response."""
        return self._request("DELETE", f"/adv/v1/embeddings/{id}")

    def get_embedding(self, id: str) -> Dict[str, Any]:
        """GET the embedding with the given id and return the decoded JSON response."""
        return self._request("GET", f"/adv/v1/embeddings/{id}")

    def get_embeddings(self) -> List[Dict[str, Any]]:
        """Return the ``"results"`` entry of the embeddings listing, or ``[]`` if absent."""
        return self._request("GET", "/adv/v1/embeddings").get("results", [])

    def import_vectors_from_file(self,
                                 id: str,
                                 file_path: str,
                                 callback: Optional[Callable[[Dict[str, Any]],
                                                             None]] = None):
        """Upload the NDJSON file at ``file_path`` to the embedding's import endpoint.

        Args:
            id: Id of the embedding that receives the vectors.
            file_path: Path to the NDJSON file to upload.
            callback: Optional callable invoked with the decoded JSON response
                of each uploaded chunk.
        """
        self._send_ndjson(f"/adv/v1/embeddings/{id}/_import_ndjson", file_path,
                          callback)

    def get_imported_vector_count(self, id: str) -> int:
        """Return the ``"count"`` field of the vector-count response (0 if absent)."""
        data = self._request("GET", f"/adv/v1/embeddings/{id}/vectors/_count")
        return data.get("count", 0)

    def _create_session(self) -> "Session":
        """Build the shared session with the bearer-token and JSON content-type headers."""
        session = requests.session()
        session.headers.update({
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        })
        return session

    def _request(self,
                 method: str,
                 path: str,
                 data: Optional[Dict[str, Any]] = None,
                 headers: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Send one request and return the decoded JSON response body.

        Args:
            method: HTTP method name passed to the session.
            path: Path appended to ``self.endpoint``.
            data: Optional payload; serialized with ``json.dumps`` when truthy,
                otherwise no body is sent.
            headers: Optional per-request headers passed to the session.

        Raises:
            Whatever ``response.raise_for_status()`` raises for an error status.
        """
        url = f"{self.endpoint}{path}"
        requests_data = None
        if data:
            requests_data = json.dumps(data)
        response = self.session.request(method,
                                        url,
                                        data=requests_data,
                                        headers=headers)
        response.raise_for_status()
        return response.json()

    def _send_ndjson(self,
                     path: str,
                     file_path: str,
                     callback: Optional[Callable[[Dict[str, Any]],
                                                 None]] = None,
                     chunk_size: int = 1000):
        """
        Sends an NDJson file in chunks.

        Args:
            path: The URL path
            file_path: The path to the NDJSON file.
            callback: A callback to run for each chunk uploaded.
            chunk_size: Maximum number of file lines sent per request.

        Raises:
            ValueError: If ``chunk_size`` is smaller than 1.
        """
        if chunk_size < 1:
            raise ValueError("chunk_size must be at least 1")

        def upload_chunk(_buffer, _count):
            # Terminate the chunk with one extra newline before measuring it.
            _buffer.write(b"\n")
            _headers = {
                "Content-Type": "application/x-ndjson",
                "X-Content-Lines": str(_count),
                # Measure the buffer that is actually being sent, not whatever
                # the enclosing scope's variable currently points at.
                "Content-Length": str(_buffer.tell())
            }
            rsp = self._send_bytes(f"{self.endpoint}{path}", _buffer, _headers)
            rsp.raise_for_status()
            if callback:
                callback(rsp.json())

        buffer = io.BytesIO()
        count = 0
        with open(file_path, 'rb') as fp:
            for line in fp:
                buffer.write(line)
                count += 1
                if count >= chunk_size:
                    upload_chunk(buffer, count)
                    # Start a fresh buffer for the next chunk.
                    buffer = io.BytesIO()
                    count = 0
        # Flush the final, partially filled chunk.
        if count:
            upload_chunk(buffer, count)

    def _send_bytes(self,
                    url: str,
                    buffer: io.BytesIO,
                    headers: Optional[Dict[str, Any]] = None) -> "Response":
        """Rewind ``buffer`` and PUT its contents to ``url``; return the raw response."""
        buffer.seek(0)
        return self.session.put(url, headers=headers, data=buffer)

    @classmethod
    def factory(cls, api_endpoint: str, api_key: str) -> "AdvClient":
        """Build a client from an API endpoint URL.

        Only the scheme and network location of ``api_endpoint`` are kept; the
        path is replaced by ``/adv``.

        Args:
            api_endpoint: URL whose scheme and host identify the deployment.
            api_key: Key sent as the bearer token on every request.

        Returns:
            An instance of ``cls``, so subclasses get an instance of themselves.
        """
        parsed_url = urlparse(api_endpoint)
        endpoint = f"{parsed_url.scheme}://{parsed_url.netloc}/adv"
        return cls(endpoint, api_key)

0 commit comments

Comments
 (0)