Skip to content

Commit c00d2ba

Browse files
committed
Streaming model download
1 parent e017b64 commit c00d2ba

File tree

3 files changed

+77
-75
lines changed

3 files changed

+77
-75
lines changed

backend/substrapp/tasks/tasks.py

Lines changed: 49 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@
2121
from celery.task import Task
2222

2323
from backend.celery import app
24-
from substrapp.utils import get_hash, get_owner, create_directory, uncompress_content, compute_hash
24+
from substrapp.utils import get_hash, get_owner, create_directory, uncompress_content
2525
from substrapp.ledger_utils import (log_start_tuple, log_success_tuple, log_fail_tuple,
2626
query_tuples, LedgerError, LedgerStatusError, get_object_from_ledger)
27-
from substrapp.tasks.utils import (ResourcesManager, compute_docker, get_asset_content, list_files,
28-
get_k8s_client, timeit)
27+
from substrapp.tasks.utils import (ResourcesManager, compute_docker, get_asset_content, get_and_put_asset_content,
28+
list_files, get_k8s_client, timeit)
2929
from substrapp.tasks.exception_handler import compute_error_code
3030

3131
logger = logging.getLogger(__name__)
@@ -136,52 +136,61 @@ def find_training_step_tuple_from_key(tuple_key):
136136
f'Key {tuple_key}: no tuple found for training step: model: {metadata}')
137137

138138

139-
def get_model_content(tuple_type, tuple_key, tuple_, out_model):
139+
def get_and_put_model_content(tuple_type, tuple_key, tuple_, out_model, model_dst_path):
140140
"""Get out model content."""
141141
owner = tuple_get_owner(tuple_type, tuple_)
142-
return get_asset_content(
142+
return get_and_put_asset_content(
143143
out_model['storageAddress'],
144144
owner,
145145
out_model['hash'],
146146
salt=tuple_key,
147+
content_dst_path=model_dst_path
147148
)
148149

149150

150-
def get_local_model_content(tuple_key, out_model):
151+
def get_and_put_local_model_content(tuple_key, out_model, model_dst_path):
151152
"""Get local model content."""
152153
from substrapp.models import Model
153154

154155
model = Model.objects.get(pk=out_model['hash'])
155-
model_content = model.file.read()
156-
computed_hash = compute_hash(model_content, key=tuple_key)
157156

158-
if computed_hash != out_model['hash']:
159-
raise Exception(f"Local model fetch error: hash doesn't match {out_model['hash']} vs {computed_hash}")
157+
# verify that local db model file is not corrupted
158+
if get_hash(model.file.path, tuple_key) != out_model['hash']:
159+
raise Exception('Local Model Hash in Subtuple is not the same as in local db')
160160

161-
return model_content
161+
if not os.path.exists(model_dst_path):
162+
os.link(model.file.path, model_dst_path)
163+
else:
164+
# verify that local subtuple model file is not corrupted
165+
if get_hash(model_dst_path, tuple_key) != out_model['hash']:
166+
raise Exception('Local Model Hash in Subtuple is not the same as in local medias')
162167

163168

164169
@timeit
165-
def get_and_put_model_content(parent_tuple_type, authorized_types, input_model, directory):
170+
def fetch_model(parent_tuple_type, authorized_types, input_model, directory):
166171

167172
tuple_type, metadata = find_training_step_tuple_from_key(input_model['traintupleKey'])
168173

169174
if tuple_type not in authorized_types:
170175
raise TasksError(f'{parent_tuple_type.capitalize()}: invalid input model: type={tuple_type}')
171176

177+
model_dst_path = path.join(directory, f'model/{input_model["traintupleKey"]}')
178+
172179
if tuple_type == TRAINTUPLE_TYPE:
173-
model = get_model_content(tuple_type, input_model['traintupleKey'], metadata, metadata['outModel'])
180+
get_and_put_model_content(
181+
tuple_type, input_model['traintupleKey'], metadata, metadata['outModel'], model_dst_path
182+
)
174183
elif tuple_type == AGGREGATETUPLE_TYPE:
175-
model = get_model_content(tuple_type, input_model['traintupleKey'], metadata, metadata['outModel'])
184+
get_and_put_model_content(
185+
tuple_type, input_model['traintupleKey'], metadata, metadata['outModel'], model_dst_path
186+
)
176187
elif tuple_type == COMPOSITE_TRAINTUPLE_TYPE:
177-
model = get_model_content(
178-
tuple_type, input_model['traintupleKey'], metadata, metadata['outTrunkModel']['outModel']
188+
get_and_put_model_content(
189+
tuple_type, input_model['traintupleKey'], metadata, metadata['outTrunkModel']['outModel'], model_dst_path
179190
)
180191
else:
181192
raise TasksError(f'Traintuple: invalid input model: type={tuple_type}')
182193

183-
_put_model(directory, model, input_model['hash'], input_model['traintupleKey'])
184-
185194

186195
def prepare_traintuple_input_models(directory, tuple_):
187196
"""Get traintuple input models content."""
@@ -193,7 +202,7 @@ def prepare_traintuple_input_models(directory, tuple_):
193202

194203
models = []
195204
for input_model in input_models:
196-
proc = Thread(target=get_and_put_model_content,
205+
proc = Thread(target=fetch_model,
197206
args=(TRAINTUPLE_TYPE, authorized_types, input_model, directory))
198207
models.append(proc)
199208
proc.start()
@@ -212,7 +221,7 @@ def prepare_aggregatetuple_input_models(directory, tuple_):
212221
models = []
213222

214223
for input_model in input_models:
215-
proc = Thread(target=get_and_put_model_content,
224+
proc = Thread(target=fetch_model,
216225
args=(AGGREGATETUPLE_TYPE, authorized_types, input_model, directory))
217226
models.append(proc)
218227
proc.start()
@@ -235,33 +244,27 @@ def prepare_composite_traintuple_input_models(directory, tuple_):
235244
if tuple_type != COMPOSITE_TRAINTUPLE_TYPE:
236245
raise TasksError(f'CompositeTraintuple: invalid head input model: type={tuple_type}')
237246
# get the output head model
238-
head_model_content = get_local_model_content(head_model_key, metadata['outHeadModel']['outModel'])
247+
head_model_dst_path = path.join(directory, f'model/{PREFIX_HEAD_FILENAME}{head_model_key}')
248+
get_and_put_local_model_content(
249+
head_model_key, metadata['outHeadModel']['outModel'], head_model_dst_path
250+
)
239251

240252
# get trunk model
241253
trunk_model_key = trunk_model['traintupleKey']
242254
tuple_type, metadata = find_training_step_tuple_from_key(trunk_model_key)
255+
trunk_model_dst_path = path.join(directory, f'model/{PREFIX_TRUNK_FILENAME}{trunk_model_key}')
243256
# trunk model must refer to a composite traintuple or an aggregatetuple
244257
if tuple_type == COMPOSITE_TRAINTUPLE_TYPE: # get output trunk model
245-
trunk_model_content = get_model_content(
246-
tuple_type, trunk_model_key, metadata, metadata['outTrunkModel']['outModel'],
258+
get_and_put_model_content(
259+
tuple_type, trunk_model_key, metadata, metadata['outTrunkModel']['outModel'], trunk_model_dst_path
247260
)
248261
elif tuple_type == AGGREGATETUPLE_TYPE:
249-
trunk_model_content = get_model_content(
250-
tuple_type, trunk_model_key, metadata, metadata['outModel'],
262+
get_and_put_model_content(
263+
tuple_type, trunk_model_key, metadata, metadata['outModel'], trunk_model_dst_path
251264
)
252265
else:
253266
raise TasksError(f'CompositeTraintuple: invalid trunk input model: type={tuple_type}')
254267

255-
# put head and trunk models
256-
_put_model(directory, head_model_content,
257-
tuple_['inHeadModel']['hash'],
258-
tuple_['inHeadModel']['traintupleKey'],
259-
filename_prefix=PREFIX_HEAD_FILENAME)
260-
_put_model(directory, trunk_model_content,
261-
tuple_['inTrunkModel']['hash'],
262-
tuple_['inTrunkModel']['traintupleKey'],
263-
filename_prefix=PREFIX_TRUNK_FILENAME)
264-
265268

266269
def prepare_testtuple_input_models(directory, tuple_):
267270
"""Get testtuple input models content."""
@@ -272,52 +275,26 @@ def prepare_testtuple_input_models(directory, tuple_):
272275

273276
if traintuple_type == TRAINTUPLE_TYPE:
274277
metadata = get_object_from_ledger(traintuple_key, 'queryTraintuple')
275-
model = get_model_content(traintuple_type, traintuple_key, metadata, metadata['outModel'])
276-
model_hash = metadata['outModel']['hash']
277-
_put_model(directory, model, model_hash, traintuple_key)
278+
model_dst_path = path.join(directory, f'model/{traintuple_key}')
279+
get_and_put_model_content(
280+
traintuple_type, traintuple_key, metadata, metadata['outModel'], model_dst_path
281+
)
278282

279283
elif traintuple_type == COMPOSITE_TRAINTUPLE_TYPE:
280284
metadata = get_object_from_ledger(traintuple_key, 'queryCompositeTraintuple')
281-
head_content = get_local_model_content(traintuple_key, metadata['outHeadModel']['outModel'])
282-
trunk_content = get_model_content(
283-
traintuple_type, traintuple_key, metadata, metadata['outTrunkModel']['outModel'],
285+
head_model_dst_path = path.join(directory, f'model/{PREFIX_HEAD_FILENAME}{traintuple_key}')
286+
get_and_put_local_model_content(traintuple_key, metadata['outHeadModel']['outModel'],
287+
head_model_dst_path)
288+
289+
model_dst_path = path.join(directory, f'model/{PREFIX_TRUNK_FILENAME}{traintuple_key}')
290+
get_and_put_model_content(
291+
traintuple_type, traintuple_key, metadata, metadata['outTrunkModel']['outModel'], model_dst_path
284292
)
285-
_put_model(directory, head_content, metadata['outHeadModel']['outModel']['hash'],
286-
traintuple_key, filename_prefix=PREFIX_HEAD_FILENAME)
287-
_put_model(directory, trunk_content, metadata['outTrunkModel']['outModel']['hash'],
288-
traintuple_key, filename_prefix=PREFIX_TRUNK_FILENAME)
289293

290294
else:
291295
raise TasksError(f"Testtuple from type '{traintuple_type}' not supported")
292296

293297

294-
def _put_model(subtuple_directory, model_content, model_hash, traintuple_hash, filename_prefix=''):
295-
if not model_content:
296-
raise Exception('Model content should not be empty')
297-
298-
from substrapp.models import Model
299-
300-
# store a model in local subtuple directory from input model content
301-
model_dst_path = path.join(subtuple_directory, f'model/{filename_prefix}{traintuple_hash}')
302-
model = None
303-
try:
304-
model = Model.objects.get(pk=model_hash)
305-
except ObjectDoesNotExist: # write it to local disk
306-
with open(model_dst_path, 'wb') as f:
307-
f.write(model_content)
308-
else:
309-
# verify that local db model file is not corrupted
310-
if get_hash(model.file.path, traintuple_hash) != model_hash:
311-
raise Exception('Model Hash in Subtuple is not the same as in local db')
312-
313-
if not os.path.exists(model_dst_path):
314-
os.link(model.file.path, model_dst_path)
315-
else:
316-
# verify that local subtuple model file is not corrupted
317-
if get_hash(model_dst_path, traintuple_hash) != model_hash:
318-
raise Exception('Model Hash in Subtuple is not the same as in local medias')
319-
320-
321298
@timeit
322299
def prepare_models(directory, tuple_type, tuple_):
323300
"""Prepare models for tuple execution.

backend/substrapp/tasks/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from subprocess import check_output
88
from django.conf import settings
99
from requests.auth import HTTPBasicAuth
10-
from substrapp.utils import get_owner, get_remote_file_content, NodeError
10+
from substrapp.utils import get_owner, get_remote_file_content, get_and_put_remote_file_content, NodeError
1111

1212
from kubernetes import client, config
1313

@@ -47,6 +47,11 @@ def get_asset_content(url, node_id, content_hash, salt=None):
4747
return get_remote_file_content(url, authenticate_worker(node_id), content_hash, salt=salt)
4848

4949

50+
def get_and_put_asset_content(url, node_id, content_hash, salt=None, content_dst_path=None):
51+
return get_and_put_remote_file_content(url, authenticate_worker(node_id), content_hash, salt=salt,
52+
content_dst_path=content_dst_path)
53+
54+
5055
@timeit
5156
def get_cpu_count(client):
5257
# Get CPU count from docker container through the API

backend/substrapp/utils.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ class NodeError(Exception):
158158
pass
159159

160160

161-
def get_remote_file(url, auth, **kwargs):
161+
def get_remote_file(url, auth, content_dst_path=None, **kwargs):
162162
kwargs.update({
163163
'headers': {'Accept': 'application/json;version=0.0'},
164164
'auth': auth
@@ -168,14 +168,21 @@ def get_remote_file(url, auth, **kwargs):
168168
kwargs['verify'] = False
169169

170170
try:
171-
response = requests.get(url, **kwargs)
171+
if kwargs.get('stream', False) and content_dst_path is not None:
172+
chunk_size = 1024 * 1024
173+
with open(content_dst_path, 'wb') as fp:
174+
response = requests.get(url, **kwargs)
175+
fp.writelines(response.iter_content(chunk_size))
176+
else:
177+
response = requests.get(url, **kwargs)
172178
except (requests.exceptions.ConnectionError, requests.exceptions.Timeout) as e:
173179
raise NodeError(f'Failed to fetch {url}') from e
174180

175181
return response
176182

177183

178184
def get_remote_file_content(url, auth, content_hash, salt=None):
185+
179186
response = get_remote_file(url, auth)
180187

181188
if response.status_code != status.HTTP_200_OK:
@@ -186,3 +193,16 @@ def get_remote_file_content(url, auth, content_hash, salt=None):
186193
if computed_hash != content_hash:
187194
raise NodeError(f"url {url}: hash doesn't match {content_hash} vs {computed_hash}")
188195
return response.content
196+
197+
198+
def get_and_put_remote_file_content(url, auth, content_hash, salt=None, content_dst_path=None):
199+
200+
response = get_remote_file(url, auth, content_dst_path, stream=True)
201+
202+
if response.status_code != status.HTTP_200_OK:
203+
logger.error(response.text)
204+
raise NodeError(f'Url: {url} returned status code: {response.status_code}')
205+
206+
computed_hash = get_hash(content_dst_path, key=salt)
207+
if computed_hash != content_hash:
208+
raise NodeError(f"url {url}: hash doesn't match {content_hash} vs {computed_hash}")

0 commit comments

Comments
 (0)