
Commit dd2e3b4

Replace direct HTTP connection with requests.Session() using connection pool (#1733)
1 parent 725c364 commit dd2e3b4

File tree

5 files changed: +67 −77 lines

libs/labelbox/src/labelbox/client.py

Lines changed: 52 additions & 68 deletions
@@ -115,16 +115,25 @@ def __init__(self,
         self.app_url = app_url
         self.endpoint = endpoint
         self.rest_endpoint = rest_endpoint
+        self._data_row_metadata_ontology = None
+        self._adv_client = AdvClient.factory(rest_endpoint, api_key)
+        self._connection: requests.Session = self._init_connection()
+
+    def _init_connection(self) -> requests.Session:
+        connection = requests.Session(
+        )  # using default connection pool size of 10
+        connection.headers.update(self._default_headers())
+
+        return connection
 
-        self.headers = {
+    def _default_headers(self):
+        return {
+            'Authorization': 'Bearer %s' % self.api_key,
             'Accept': 'application/json',
             'Content-Type': 'application/json',
-            'Authorization': 'Bearer %s' % api_key,
             'X-User-Agent': f"python-sdk {SDK_VERSION}",
             'X-Python-Version': f"{python_version_info()}",
         }
-        self._data_row_metadata_ontology = None
-        self._adv_client = AdvClient.factory(rest_endpoint, api_key)
 
     @retry.Retry(predicate=retry.if_exception_type(
         labelbox.exceptions.InternalServerError,
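
For context on the hunk above: a requests.Session keeps TCP/TLS connections alive in an urllib3 pool (the default HTTPAdapter holds up to 10 connections per host), and headers set on the session are sent with every request made through it. A minimal sketch of the same pattern, with placeholder values for the key and endpoint rather than the SDK's real configuration:

import requests

API_KEY = "<api-key>"                  # placeholder; supplied by the caller in the SDK
BASE_URL = "https://example.com/api"   # placeholder endpoint for illustration

session = requests.Session()           # default HTTPAdapter pool: up to 10 connections per host
session.headers.update({
    'Authorization': 'Bearer %s' % API_KEY,
    'Accept': 'application/json',
    'Content-Type': 'application/json',
})

# Both calls reuse the same pooled connection and inherit the headers above,
# instead of opening a fresh connection per request as module-level requests.get/post do.
session.get(BASE_URL + "/resource")
session.get(BASE_URL + "/resource")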
@@ -193,18 +202,19 @@ def convert_value(value):
                 "/graphql", "/_gql")
 
         try:
-            request = {
-                'url': endpoint,
-                'data': data,
-                'headers': self.headers,
-                'timeout': timeout
-            }
+            headers = self._connection.headers.copy()
             if files:
-                request.update({'files': files})
-                request['headers'] = {
-                    'Authorization': self.headers['Authorization']
-                }
-            response = requests.post(**request)
+                del headers['Content-Type']
+                del headers['Accept']
+            request = requests.Request('POST',
+                                       endpoint,
+                                       headers=headers,
+                                       data=data,
+                                       files=files if files else None)
+
+            prepped: requests.PreparedRequest = request.prepare()
+
+            response = self._connection.send(prepped, timeout=timeout)
             logger.debug("Response: %s", response.text)
         except requests.exceptions.Timeout as e:
             raise labelbox.exceptions.TimeoutError(str(e))
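
The rewritten execute path builds a requests.Request, prepares it, and sends it through the session, so the pooled connection and the timeout still apply; for multipart (files) requests the JSON Content-Type and Accept defaults are removed so that requests can set the multipart boundary itself. A sketch of the same flow under the assumption of a standalone session (session, endpoint, and post_graphql are illustrative names, not SDK API):

import requests

session = requests.Session()
session.headers.update({'Authorization': 'Bearer <api-key>',
                        'Accept': 'application/json',
                        'Content-Type': 'application/json'})

def post_graphql(endpoint, data, files=None, timeout=60.0):
    # Request.prepare() does not merge session headers, so copy them explicitly.
    headers = session.headers.copy()
    if files:
        # Drop the JSON defaults so requests can set the
        # multipart/form-data Content-Type with its boundary.
        del headers['Content-Type']
        del headers['Accept']
    request = requests.Request('POST',
                               endpoint,
                               headers=headers,
                               data=data,
                               files=files if files else None)
    prepped = request.prepare()
    # send() reuses the session's connection pool and applies the timeout.
    return session.send(prepped, timeout=timeout)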
@@ -409,14 +419,21 @@ def upload_data(self,
             "map": (None, json.dumps({"1": ["variables.file"]})),
         }
 
-        response = requests.post(
-            self.endpoint,
-            headers={"authorization": "Bearer %s" % self.api_key},
-            data=request_data,
-            files={
-                "1": (filename, content, content_type) if
-                (filename and content_type) else content
-            })
+        files = {
+            "1": (filename, content, content_type) if
+                 (filename and content_type) else content
+        }
+        headers = self._connection.headers.copy()
+        headers.pop("Content-Type", None)
+        request = requests.Request('POST',
+                                   self.endpoint,
+                                   headers=headers,
+                                   data=request_data,
+                                   files=files)
+
+        prepped: requests.PreparedRequest = request.prepare()
+
+        response = self._connection.send(prepped)
 
         if response.status_code == 502:
             error_502 = '502 Bad Gateway'
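
The upload path above pops Content-Type from the copied session headers for the same reason as in execute: requests only sets the multipart boundary when no Content-Type header is already present on the prepared request. A quick illustrative check (not SDK code; URL and file contents are placeholders):

import requests

req = requests.Request('POST', 'https://example.com/upload',
                       headers={'Content-Type': 'application/json'},
                       files={'1': ('f.txt', b'hello', 'text/plain')})
print(req.prepare().headers['Content-Type'])
# -> 'application/json' (boundary lost; the server cannot parse the multipart body)

req.headers.pop('Content-Type')
print(req.prepare().headers['Content-Type'])
# -> 'multipart/form-data; boundary=...'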
@@ -1085,6 +1102,7 @@ def get_feature_schema(self, feature_schema_id):
         query_str = """query rootSchemaNodePyApi($rootSchemaNodeWhere: RootSchemaNodeWhere!){
             rootSchemaNode(where: $rootSchemaNodeWhere){%s}
         }""" % query.results_query_part(Entity.FeatureSchema)
+
         res = self.execute(
             query_str,
             {'rootSchemaNodeWhere': {
@@ -1195,10 +1213,7 @@ def delete_unused_feature_schema(self, feature_schema_id: str) -> None:
 
         endpoint = self.rest_endpoint + "/feature-schemas/" + urllib.parse.quote(
             feature_schema_id)
-        response = requests.delete(
-            endpoint,
-            headers=self.headers,
-        )
+        response = self._connection.delete(endpoint)
 
         if response.status_code != requests.codes.no_content:
             raise labelbox.exceptions.LabelboxError(
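
This hunk and the REST helpers below follow one pattern: the per-call headers=self.headers argument disappears because the session's verb helpers (get, post, put, patch, delete) merge the session headers into every request and draw from the same connection pool. A minimal sketch of the equivalent call outside the SDK (endpoint is a placeholder):

import requests

session = requests.Session()
session.headers.update({'Authorization': 'Bearer <api-key>',
                        'Accept': 'application/json',
                        'Content-Type': 'application/json'})

# Roughly equivalent to the old requests.delete(endpoint, headers=self.headers):
# the session helper sends the session headers and reuses the pooled connection.
endpoint = "https://example.com/api/feature-schemas/some-id"  # placeholder
response = session.delete(endpoint)
if response.status_code != requests.codes.no_content:
    raise RuntimeError("delete failed: " + response.text)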
@@ -1215,10 +1230,7 @@ def delete_unused_ontology(self, ontology_id: str) -> None:
         """
         endpoint = self.rest_endpoint + "/ontologies/" + urllib.parse.quote(
             ontology_id)
-        response = requests.delete(
-            endpoint,
-            headers=self.headers,
-        )
+        response = self._connection.delete(endpoint)
 
         if response.status_code != requests.codes.no_content:
             raise labelbox.exceptions.LabelboxError(
@@ -1240,11 +1252,7 @@ def update_feature_schema_title(self, feature_schema_id: str,
 
         endpoint = self.rest_endpoint + "/feature-schemas/" + urllib.parse.quote(
             feature_schema_id) + '/definition'
-        response = requests.patch(
-            endpoint,
-            headers=self.headers,
-            json={"title": title},
-        )
+        response = self._connection.patch(endpoint, json={"title": title})
 
         if response.status_code == requests.codes.ok:
             return self.get_feature_schema(feature_schema_id)
@@ -1273,11 +1281,8 @@ def upsert_feature_schema(self, feature_schema: Dict) -> FeatureSchema:
             "featureSchemaId") or "new_feature_schema_id"
         endpoint = self.rest_endpoint + "/feature-schemas/" + urllib.parse.quote(
             feature_schema_id)
-        response = requests.put(
-            endpoint,
-            headers=self.headers,
-            json={"normalized": json.dumps(feature_schema)},
-        )
+        response = self._connection.put(
+            endpoint, json={"normalized": json.dumps(feature_schema)})
 
         if response.status_code == requests.codes.ok:
             return self.get_feature_schema(response.json()['schemaId'])
@@ -1303,11 +1308,7 @@ def insert_feature_schema_into_ontology(self, feature_schema_id: str,
         endpoint = self.rest_endpoint + '/ontologies/' + urllib.parse.quote(
             ontology_id) + "/feature-schemas/" + urllib.parse.quote(
                 feature_schema_id)
-        response = requests.post(
-            endpoint,
-            headers=self.headers,
-            json={"position": position},
-        )
+        response = self._connection.post(endpoint, json={"position": position})
         if response.status_code != requests.codes.created:
             raise labelbox.exceptions.LabelboxError(
                 "Failed to insert the feature schema into the ontology, message: "
@@ -1328,11 +1329,7 @@ def get_unused_ontologies(self, after: str = None) -> List[str]:
         """
 
         endpoint = self.rest_endpoint + "/ontologies/unused"
-        response = requests.get(
-            endpoint,
-            headers=self.headers,
-            json={"after": after},
-        )
+        response = self._connection.get(endpoint, json={"after": after})
 
         if response.status_code == requests.codes.ok:
             return response.json()
@@ -1356,11 +1353,7 @@ def get_unused_feature_schemas(self, after: str = None) -> List[str]:
         """
 
         endpoint = self.rest_endpoint + "/feature-schemas/unused"
-        response = requests.get(
-            endpoint,
-            headers=self.headers,
-            json={"after": after},
-        )
+        response = self._connection.get(endpoint, json={"after": after})
 
         if response.status_code == requests.codes.ok:
             return response.json()
@@ -1881,10 +1874,7 @@ def is_feature_schema_archived(self, ontology_id: str,
 
         ontology_endpoint = self.rest_endpoint + "/ontologies/" + urllib.parse.quote(
             ontology_id)
-        response = requests.get(
-            ontology_endpoint,
-            headers=self.headers,
-        )
+        response = self._connection.get(ontology_endpoint)
 
         if response.status_code == requests.codes.ok:
             feature_schema_nodes = response.json()['featureSchemaNodes']
@@ -1960,10 +1950,7 @@ def delete_feature_schema_from_ontology(
         ontology_endpoint = self.rest_endpoint + "/ontologies/" + urllib.parse.quote(
             ontology_id) + "/feature-schemas/" + urllib.parse.quote(
                 feature_schema_id)
-        response = requests.delete(
-            ontology_endpoint,
-            headers=self.headers,
-        )
+        response = self._connection.delete(ontology_endpoint)
 
         if response.status_code == requests.codes.ok:
             response_json = response.json()
@@ -1997,10 +1984,7 @@ def unarchive_feature_schema_node(self, ontology_id: str,
         ontology_endpoint = self.rest_endpoint + "/ontologies/" + urllib.parse.quote(
             ontology_id) + '/feature-schemas/' + urllib.parse.quote(
                 root_feature_schema_id) + '/unarchive'
-        response = requests.patch(
-            ontology_endpoint,
-            headers=self.headers,
-        )
+        response = self._connection.patch(ontology_endpoint)
         if response.status_code == requests.codes.ok:
             if not bool(response.json()['unarchived']):
                 raise labelbox.exceptions.LabelboxError(
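
The new session relies on the library defaults noted in the "# using default connection pool size of 10" comment. If a larger pool were ever needed, it could be configured by mounting a custom HTTPAdapter; this is a hypothetical extension for illustration, not part of this commit:

import requests
from requests.adapters import HTTPAdapter

session = requests.Session()
# Hypothetical tuning: raise the per-host pool size above the default of 10.
adapter = HTTPAdapter(pool_connections=10, pool_maxsize=32)
session.mount("https://", adapter)
session.mount("http://", adapter)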

libs/labelbox/tests/integration/test_filtering.py

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,8 @@ def project_to_test_where(client, rand_gen):
 
 # Avoid assertions using equality to prevent intermittent failures due to
 # other builds simultaneously adding projects to test org
+@pytest.mark.skip(
+    reason="broken due to get_projects HF for sunset-custom-editor")
 def test_where(client, project_to_test_where):
     p_a, p_b, p_c = project_to_test_where
     p_a_name = p_a.name

libs/labelbox/tests/integration/test_labeler_performance.py

Lines changed: 2 additions & 2 deletions
@@ -4,8 +4,8 @@
 
 
 @pytest.mark.skipif(
-    condition=os.environ['LABELBOX_TEST_ENVIRON'] == "onprem",
-    reason="longer runtime than expected for onprem. unskip when resolved.")
+    condition=os.environ['LABELBOX_TEST_ENVIRON'] != "prod",
+    reason="only works for prod")
 def test_labeler_performance(configured_project_with_label):
     project, _, _, _ = configured_project_with_label
 
libs/labelbox/tests/integration/test_labeling_frontend.py

Lines changed: 3 additions & 0 deletions
@@ -1,4 +1,5 @@
 from labelbox import LabelingFrontend
+import pytest
 
 
 def test_get_labeling_frontends(client):
@@ -7,6 +8,8 @@ def test_get_labeling_frontends(client):
     assert len(filtered_frontends)
 
 
+@pytest.mark.skip(
+    reason="broken due to get_projects HF for sunset-custom-editor")
 def test_labeling_frontend_connecting_to_project(project):
     client = project.client
     default_labeling_frontend = next(

libs/labelbox/tests/integration/test_project.py

Lines changed: 8 additions & 7 deletions
@@ -119,6 +119,8 @@ def delete_tag(tag_id: str):
     delete_tag(tagB.uid)
 
 
+@pytest.mark.skip(
+    reason="broken due to get_projects HF for sunset-custom-editor")
 def test_project_filtering(client, rand_gen, data_for_project_test):
     name_1 = rand_gen(str)
     p1 = data_for_project_test(name_1)
@@ -144,7 +146,6 @@ def test_attach_instructions(client, project):
     assert str(
         execinfo.value
     ) == "Cannot attach instructions to a project that has not been set up."
-
     editor = list(
         client.get_labeling_frontends(
             where=LabelingFrontend.name == "editor"))[0]
@@ -218,7 +219,7 @@ def test_create_batch_with_global_keys_sync(project: Project, data_rows):
     global_keys = [dr.global_key for dr in data_rows]
     batch_name = f'batch {uuid.uuid4()}'
     batch = project.create_batch(batch_name, global_keys=global_keys)
-
+
     assert batch.size == len(set(data_rows))
 
 
@@ -227,7 +228,7 @@ def test_create_batch_with_global_keys_async(project: Project, data_rows):
     global_keys = [dr.global_key for dr in data_rows]
     batch_name = f'batch {uuid.uuid4()}'
     batch = project._create_batch_async(batch_name, global_keys=global_keys)
-
+
     assert batch.size == len(set(data_rows))
 
 
@@ -245,8 +246,7 @@ def test_media_type(client, project: Project, rand_gen):
         # Exclude LLM media types for now, as they are not supported
         if MediaType[media_type] in [
                 MediaType.LLMPromptCreation,
-                MediaType.LLMPromptResponseCreation,
-                MediaType.LLM
+                MediaType.LLMPromptResponseCreation, MediaType.LLM
         ]:
             continue
 
@@ -284,7 +284,8 @@ def test_label_count(client, configured_batch_project_with_label):
 
 def test_clone(client, project, rand_gen):
     # cannot clone unknown project media type
-    project = client.create_project(name=rand_gen(str), media_type=MediaType.Image)
+    project = client.create_project(name=rand_gen(str),
+                                    media_type=MediaType.Image)
     cloned_project = project.clone()
 
     assert cloned_project.description == project.description
@@ -295,4 +296,4 @@ def test_clone(client, project, rand_gen):
     assert cloned_project.get_label_count() == 0
 
     project.delete()
-    cloned_project.delete()
+    cloned_project.delete()
