Skip to content

Commit 76c7874

Browse files
author
Billy Cheung
committed
To test a PR of the autorope/donkeycar repo:
https://github.com/autorope/donkeycar/pull/782/files (Add datastore metadata function and test autorope#782)
1 parent 5dcc43a commit 76c7874

File tree

5 files changed

+238
-72
lines changed

5 files changed

+238
-72
lines changed

donkeycar/management/base.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,79 @@ def run(self, args):
453453
else:
454454
print("Unrecognized framework: {}. Please specify one of 'tensorflow' or 'pytorch'".format(framework))
455455

456+
# train remotely =====================================================================================================
457+
import requests
458+
import tempfile
459+
import tarfile
460+
from pathlib import Path
461+
from requests_toolbelt.multipart.encoder import MultipartEncoder
462+
463+
class TrainRemote(BaseCommand):
464+
465+
@staticmethod
466+
def generate_tub_archive(self, tub_paths, carapp_path):
467+
print("generating tub archive")
468+
f = tempfile.NamedTemporaryFile(mode='w+b', suffix='.tar.gz', delete=False)
469+
470+
with tarfile.open(fileobj=f, mode='w:gz') as tar:
471+
for tub_path in tub_paths:
472+
p = Path(tub_path)
473+
tar.add(p, arcname=p.name)
474+
tar.add(f"{carapp_path}/myconfig.py", arcname="myconfig.py")
475+
f.close()
476+
477+
return f.name
478+
479+
@staticmethod
480+
def submit_train_job(self, carapp_path, submit_job_url, tub_paths):
481+
filename = self.generate_tub_archive(tub_paths, carapp_path)
482+
deviceID = "device_id"
483+
hostname = "hostname"
484+
485+
mp_encoder = MultipartEncoder(
486+
fields={
487+
'device_id': deviceID,
488+
'hostname': hostname,
489+
'tub_archive_file': ('file.tar.gz', open(filename, 'rb'), 'application/gzip'),
490+
}
491+
)
492+
493+
r = requests.post(
494+
submit_job_url,
495+
data=mp_encoder, # The MultipartEncoder is posted as data, don't use files=...!
496+
# The MultipartEncoder provides the content-type header with the boundary:
497+
headers={'Content-Type': mp_encoder.content_type}
498+
)
499+
500+
if (r.status_code == 200):
501+
# if HTTP 200 OK
502+
if ("job_uuid" in r.json()):
503+
try:
504+
print(r.json()['job_uuid'])
505+
except Exception as e:
506+
print(e)
507+
raise Exception("Failed to call submit job")
508+
else:
509+
raise Exception("Failed to call submit job")
510+
else:
511+
raise Exception("Failed to call submit job")
512+
513+
def parse_args(self, args):
514+
parser = argparse.ArgumentParser(prog='train', usage='%(prog)s [options]')
515+
parser.add_argument('--tub', nargs='+', help='tub data for training')
516+
parser.add_argument('--server', default=None, help='url of the training server')
517+
parser.add_argument('--carpath', default='.', help='path of mycar folder')
518+
parsed_args = parser.parse_args(args)
519+
return parsed_args
520+
521+
def run(self, args):
522+
args = self.parse_args(args)
523+
cfg = load_config(args.config)
524+
if args.server :
525+
self.submit_train_job(args.carpath, args.server, args.tub)
526+
else:
527+
self.submit_train_job(args.carpath, "https://hq.robocarstore.com/train/submit_job", args.tub)
528+
# =====================================================================================================
456529

457530
def execute_from_command_line():
458531
"""
@@ -469,6 +542,7 @@ def execute_from_command_line():
469542
'cnnactivations': ShowCnnActivations,
470543
'update': UpdateCar,
471544
'train': Train,
545+
'trainremote': TrainRemote,
472546
}
473547

474548
args = sys.argv[:]

donkeycar/parts/datastore_v2.py

Lines changed: 113 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,22 @@
1010

1111

1212
class Seekable(object):
13-
'''
14-
A seekable file reader, writer which deals with newline delimited records. \n
15-
This reader maintains an index of line lengths, so seeking a line is a O(1) operation.
16-
'''
13+
"""
14+
A seekable file reader, writer which deals with newline delimited
15+
records. \n
16+
This reader maintains an index of line lengths, so seeking a line is a
17+
O(1) operation.
18+
"""
1719

1820
def __init__(self, file, read_only=False, line_lengths=list()):
1921
self.line_lengths = list()
2022
self.cumulative_lengths = list()
2123
self.method = 'r' if read_only else 'a+'
2224
self.file = open(file, self.method, newline=NEWLINE)
25+
# If file is read only improve performance by memory mapping the file.
2326
if self.method == 'r':
24-
# If file is read only improve performance by memory mappping the file.
25-
self.file = mmap.mmap(self.file.fileno(), length=0, access=mmap.ACCESS_READ)
27+
self.file = mmap.mmap(self.file.fileno(), length=0,
28+
access=mmap.ACCESS_READ)
2629
self.total_length = 0
2730
if len(line_lengths) <= 0:
2831
self._read_contents()
@@ -74,7 +77,8 @@ def _line_end_offset(self, line_number):
7477

7578
def _offset_until(self, line_index):
7679
end_index = line_index - 1
77-
return self.cumulative_lengths[end_index] if end_index >= 0 and end_index < len(self.cumulative_lengths) else 0
80+
return self.cumulative_lengths[end_index] \
81+
if 0 <= end_index < len(self.cumulative_lengths) else 0
7882

7983
def readline(self):
8084
contents = self.file.readline()
@@ -92,7 +96,8 @@ def seek_end_of_file(self):
9296
def truncate_until_end(self, line_number):
9397
self.line_lengths = self.line_lengths[:line_number]
9498
self.cumulative_lengths = self.cumulative_lengths[:line_number]
95-
self.total_length = self.cumulative_lengths[-1] if len(self.cumulative_lengths) > 0 else 0
99+
self.total_length = self.cumulative_lengths[-1] \
100+
if len(self.cumulative_lengths) > 0 else 0
96101
self.seek_end_of_file()
97102
self.file.truncate()
98103

@@ -133,15 +138,18 @@ def __exit__(self, type, value, traceback):
133138
class Catalog(object):
134139
'''
135140
A new line delimited file that has records delimited by newlines. \n
136-
137141
[ json object record ] \n
138142
[ json object record ] \n
139143
...
140144
'''
141145
def __init__(self, path, read_only=False, start_index=0):
142146
self.path = Path(os.path.expanduser(path))
143-
self.manifest = CatalogMetadata(self.path, read_only=read_only, start_index=start_index)
144-
self.seekable = Seekable(self.path.as_posix(), line_lengths=self.manifest.line_lengths(), read_only=read_only)
147+
self.manifest = CatalogMetadata(self.path,
148+
read_only=read_only,
149+
start_index=start_index)
150+
self.seekable = Seekable(self.path.as_posix(),
151+
line_lengths=self.manifest.line_lengths(),
152+
read_only=read_only)
145153

146154
def _exit_handler(self):
147155
self.close()
@@ -164,8 +172,9 @@ class CatalogMetadata(object):
164172
'''
165173
def __init__(self, catalog_path, read_only=False, start_index=0):
166174
path = Path(catalog_path)
167-
manifest_name = '%s.catalog_manifest' % (path.stem)
168-
self.manifest_path = Path(os.path.join(path.parent.as_posix(), manifest_name))
175+
manifest_name = f'{path.stem}.catalog_manifest'
176+
self.manifest_path = Path(os.path.join(path.parent.as_posix(),
177+
manifest_name))
169178
self.seekeable = Seekable(self.manifest_path, read_only=read_only)
170179
has_contents = False
171180
if os.path.exists(self.manifest_path) and self.seekeable.has_content():
@@ -207,7 +216,6 @@ def close(self):
207216
class Manifest(object):
208217
'''
209218
A newline delimited file, with the following format.
210-
211219
[ json array of inputs ]\n
212220
[ json array of types ]\n
213221
[ json object with user metadata ]\n
@@ -230,50 +238,75 @@ def __init__(self, base_path, inputs=[], types=[], metadata=[],
230238
self.catalog_paths = list()
231239
self.catalog_metadata = dict()
232240
self.deleted_indexes = set()
241+
self._updated_session = False
233242
has_catalogs = False
234243

235244
if self.manifest_path.exists():
236245
self.seekeable = Seekable(self.manifest_path, read_only=self.read_only)
237246
if self.seekeable.has_content():
238247
self._read_contents()
239248
has_catalogs = len(self.catalog_paths) > 0
249+
240250
else:
241251
created_at = time.time()
242252
self.manifest_metadata['created_at'] = created_at
243253
if not self.base_path.exists():
244254
self.base_path.mkdir(parents=True, exist_ok=True)
245-
print('Created a new datastore at %s' % (self.base_path.as_posix()))
255+
print(f'Created a new datastore at {self.base_path.as_posix()}')
246256
self.seekeable = Seekable(self.manifest_path, read_only=self.read_only)
247257

248258
if not has_catalogs:
249259
self._write_contents()
250260
self._add_catalog()
251261
else:
252-
last_known_catalog = os.path.join(self.base_path, self.catalog_paths[-1])
253-
print('Using catalog %s' % (last_known_catalog))
254-
self.current_catalog = Catalog(last_known_catalog, read_only=self.read_only, start_index=self.current_index)
262+
last_known_catalog = os.path.join(self.base_path,
263+
self.catalog_paths[-1])
264+
print(f'Using catalog {last_known_catalog}')
265+
self.current_catalog = Catalog(last_known_catalog,
266+
read_only=self.read_only,
267+
start_index=self.current_index)
268+
# Create a new session_id, which will be added to each record in the
269+
# tub, when Tub.write_record() is called.
270+
self.session_id = self.create_new_session()
255271

256272
def write_record(self, record):
257-
new_catalog = self.current_index > 0 and (self.current_index % self.max_len) == 0
273+
new_catalog = self.current_index > 0 \
274+
and (self.current_index % self.max_len) == 0
258275
if new_catalog:
259276
self._add_catalog()
260277

261278
self.current_catalog.write_record(record)
262279
self.current_index += 1
263280
# Update metadata to keep track of the last index
264281
self._update_catalog_metadata(update=True)
282+
# Set session_id update status to True if this method is called at
283+
# least once. Then session id metadata will be updated when the
284+
# session gets closed
285+
if not self._updated_session:
286+
self._updated_session = True
265287

266288
def delete_record(self, record_index):
267289
# Does not actually delete the record, but marks it as deleted.
268290
self.deleted_indexes.add(record_index)
269291
self._update_catalog_metadata(update=True)
270292

293+
def update_metadata(self, metadata):
294+
self.metadata = {**self.metadata, **metadata}
295+
self._write_contents()
296+
297+
def restore_record(self, record_index):
298+
# Does not actually delete the record, but marks it as deleted.
299+
self.deleted_indexes.discard(record_index)
300+
self._update_catalog_metadata(update=True)
301+
271302
def _add_catalog(self):
272303
current_length = len(self.catalog_paths)
273-
catalog_name = 'catalog_%s.catalog' % (current_length)
304+
catalog_name = f'catalog_{current_length}.catalog'
274305
catalog_path = os.path.join(self.base_path, catalog_name)
275306
current_catalog = self.current_catalog
276-
self.current_catalog = Catalog(catalog_path, start_index=self.current_index, read_only=self.read_only)
307+
self.current_catalog = Catalog(catalog_path,
308+
start_index=self.current_index,
309+
read_only=self.read_only)
277310
# Store relative paths
278311
self.catalog_paths.append(catalog_name)
279312
self._update_catalog_metadata(update=True)
@@ -318,7 +351,30 @@ def _update_catalog_metadata(self, update=True):
318351
self.catalog_metadata = catalog_metadata
319352
self.seekeable.writeline(json.dumps(catalog_metadata))
320353

354+
def create_new_session(self):
355+
""" Creates a new session id and appends it to the metadata."""
356+
sessions = self.manifest_metadata.get('sessions', {})
357+
last_id = -1
358+
if sessions:
359+
last_id = sessions['last_id']
360+
else:
361+
sessions['all_full_ids'] = []
362+
this_id = last_id + 1
363+
date = time.strftime('%y-%m-%d')
364+
this_full_id = date + '_' + str(this_id)
365+
sessions['last_id'] = this_id
366+
sessions['last_full_id'] = this_full_id
367+
sessions['all_full_ids'].append(this_full_id)
368+
self.manifest_metadata['sessions'] = sessions
369+
return this_full_id
370+
321371
def close(self):
372+
""" Closing tub closes open files for catalog, catalog manifest and
373+
manifest.json"""
374+
# If records were received, write updated session_id dictionary into
375+
# the metadata, otherwise keep the session_id information unchanged
376+
if self._updated_session:
377+
self.seekeable.update_line(4, json.dumps(self.manifest_metadata))
322378
self.current_catalog.close()
323379
self.seekeable.close()
324380

@@ -331,11 +387,10 @@ def __len__(self):
331387

332388

333389
class ManifestIterator(object):
334-
'''
390+
"""
335391
An iterator for the Manifest type. \n
336-
337392
Returns catalog entries lazily when a consumer calls __next__().
338-
'''
393+
"""
339394
def __init__(self, manifest):
340395
self.manifest = manifest
341396
self.has_catalogs = len(self.manifest.catalog_paths) > 0
@@ -344,39 +399,42 @@ def __init__(self, manifest):
344399
self.current_catalog = None
345400

346401
def __next__(self):
347-
if not self.has_catalogs:
348-
raise StopIteration('No catalogs')
349-
350-
if self.current_catalog_index >= len(self.manifest.catalog_paths):
351-
raise StopIteration('No more catalogs')
352-
353-
if self.current_catalog is None:
354-
current_catalog_path = os.path.join(self.manifest.base_path, self.manifest.catalog_paths[self.current_catalog_index])
355-
self.current_catalog = Catalog(current_catalog_path, read_only=self.manifest.read_only)
356-
self.current_catalog.seekable.seek_line_start(1)
357-
358-
contents = self.current_catalog.seekable.readline()
359-
360-
if contents is not None and len(contents) > 0:
361-
# Check for current_index when we are ready to advance the underlying iterator.
362-
current_index = self.current_index
363-
self.current_index += 1
364-
if current_index in self.manifest.deleted_indexes:
365-
# Skip over index, because it has been marked deleted
366-
return self.__next__()
402+
while True:
403+
if not self.has_catalogs:
404+
raise StopIteration('No catalogs')
405+
406+
if self.current_catalog_index >= len(self.manifest.catalog_paths):
407+
raise StopIteration('No more catalogs')
408+
409+
if self.current_catalog is None:
410+
current_catalog_path = os.path.join(
411+
self.manifest.base_path,
412+
self.manifest.catalog_paths[self.current_catalog_index])
413+
self.current_catalog = Catalog(current_catalog_path,
414+
read_only=self.manifest.read_only)
415+
self.current_catalog.seekable.seek_line_start(1)
416+
417+
contents = self.current_catalog.seekable.readline()
418+
if contents is not None and len(contents) > 0:
419+
# Check for current_index when we are ready to advance the
420+
# underlying iterator.
421+
current_index = self.current_index
422+
self.current_index += 1
423+
if current_index in self.manifest.deleted_indexes:
424+
# Skip over index, because it has been marked deleted
425+
continue
426+
else:
427+
try:
428+
record = json.loads(contents)
429+
return record
430+
except Exception:
431+
print(f'Ignoring record at index {current_index}')
432+
continue
367433
else:
368-
try:
369-
record = json.loads(contents)
370-
return record
371-
except Exception:
372-
print('Ignoring record at index %s' % (current_index))
373-
return self.__next__()
374-
else:
375-
self.current_catalog = None
376-
self.current_catalog_index += 1
377-
return self.__next__()
434+
self.current_catalog = None
435+
self.current_catalog_index += 1
378436

379437
next = __next__
380438

381439
def __len__(self):
382-
return self.manifest.__len__()
440+
return self.manifest.__len__()

0 commit comments

Comments
 (0)