Skip to content

Commit 1dd7264

Browse files
authored
Fix duplicate cache entries, added CacheAction inside ScriptAutomation (#171)
* Fix typo in clean-nvidia-scratch-space
* CacheAction object added inside ScriptAction, fix duplicate cache entries being created for git repos
1 parent 7f1550a commit 1dd7264

File tree

3 files changed

+50
-41
lines changed

3 files changed

+50
-41
lines changed

.github/workflows/test-nvidia-mlperf-inference-implementations.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
strategy:
1818
fail-fast: false
1919
matrix:
20-
system: [ "GO-spr", "phoenix", "GO-i9"]
20+
system: [ "GO-spr", "phoenix"]
2121
# system: [ "mlc-server" ]
2222
python-version: [ "3.12" ]
2323
model: [ "resnet50", "retinanet", "bert-99", "bert-99.9", "gptj-99.9", "3d-unet-99.9", "sdxl" ]

automation/script/module.py

Lines changed: 46 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import logging
1313

1414
from mlc.main import Automation
15+
from mlc.main import CacheAction
1516
import mlc.utils as utils
1617
from utils import *
1718

@@ -39,6 +40,8 @@ def __init__(self, action_object, automation_file):
3940
self.logger = self.action_object.logger
4041
self.logger.propagate = False
4142

43+
# Create CacheAction using the same parent as the Script
44+
self.cache_action = CacheAction(self.action_object.parent)
4245
self.tmp_file_env = 'tmp-env'
4346
self.tmp_file_env_all = 'tmp-env-all'
4447
self.tmp_file_run = 'tmp-run'
@@ -351,7 +354,6 @@ def _run(self, i):
351354
skip_cache = i.get('skip_cache', False)
352355
force_cache = i.get('force_cache', False)
353356

354-
fake_run = i.get('fake_run', False)
355357
fake_run = i.get(
356358
'fake_run',
357359
False) if 'fake_run' in i else i.get(
@@ -670,7 +672,7 @@ def _run(self, i):
670672
search_cache = {'action': 'search',
671673
'target_name': 'cache',
672674
'tags': cache_tags_without_tmp_string}
673-
rc = self.action_object.access(search_cache)
675+
rc = self.cache_action.access(search_cache)
674676
if rc['return'] > 0:
675677
return rc
676678

@@ -1383,19 +1385,18 @@ def _run(self, i):
13831385
' - Creating new "cache" script artifact in the MLC local repository ...')
13841386
logger.debug(recursion_spaces +
13851387
' - Tags: {}'.format(','.join(tmp_tags)))
1386-
13871388
if version != '':
13881389
cached_meta['version'] = version
13891390

13901391
ii = {'action': 'update',
1391-
'automation': self.meta['deps']['cache'],
1392+
'target': 'cache',
13921393
'search_tags': tmp_tags,
13931394
'script_alias': meta['alias'],
13941395
'tags': ','.join(tmp_tags),
13951396
'meta': cached_meta,
13961397
'force': True}
13971398

1398-
r = self.action_object.access(ii)
1399+
r = self.cache_action.access(ii)
13991400
if r['return'] > 0:
14001401
return r
14011402

@@ -1495,8 +1496,9 @@ def _run(self, i):
14951496
if r['return'] > 0:
14961497
return r
14971498

1498-
if 'version-' + version not in cached_tags:
1499-
cached_tags.append('version-' + version)
1499+
r = get_version_tag_from_version(version, cached_tags)
1500+
if r['return'] > 0:
1501+
return r
15001502

15011503
if default_version in versions:
15021504
versions_meta = versions[default_version]
@@ -1828,10 +1830,13 @@ def _run(self, i):
18281830

18291831
# If return version
18301832
if cache:
1831-
if r.get('version', '') != '':
1833+
version = r.get('version', '')
1834+
if version != '':
18321835
cached_tags = [
18331836
x for x in cached_tags if not x.startswith('version-')]
1834-
cached_tags.append('version-' + r['version'])
1837+
r = get_version_tag_from_version(version, cached_tags)
1838+
if r['return'] > 0:
1839+
return r
18351840

18361841
if len(r.get('add_extra_cache_tags', [])) > 0:
18371842
for t in r['add_extra_cache_tags']:
@@ -1873,9 +1878,14 @@ def _run(self, i):
18731878
if r.get('version', '') != '':
18741879
version = r.get('version')
18751880
if cache:
1876-
cached_tags = [
1877-
x for x in cached_tags if not x.startswith('version-')]
1878-
cached_tags.append('version-' + r['version'])
1881+
version = r.get('version', '')
1882+
if version != '':
1883+
cached_tags = [
1884+
x for x in cached_tags if not x.startswith('version-')]
1885+
r = get_version_tag_from_version(
1886+
version, cached_tags)
1887+
if r['return'] > 0:
1888+
return r
18791889

18801890
if len(r.get('add_extra_cache_tags', [])) > 0 and cache:
18811891
for t in r['add_extra_cache_tags']:
@@ -2034,14 +2044,14 @@ def _run(self, i):
20342044
cached_meta['dependent_cached_path'] = dependent_cached_path
20352045

20362046
ii = {'action': 'update',
2037-
'automation': self.meta['deps']['cache'],
2047+
'target': 'cache',
20382048
'uid': cached_uid,
20392049
'meta': cached_meta,
20402050
'script_alias': meta['alias'],
20412051
'replace_lists': True, # To replace tags
20422052
'tags': ','.join(cached_tags)}
20432053

2044-
r = self.action_object.access(ii)
2054+
r = self.cache_action.access(ii)
20452055
if r['return'] > 0:
20462056
return r
20472057

@@ -4757,7 +4767,20 @@ def clean_some_tmp_files(self, i):
47574767
return {'return': 0}
47584768

47594769

4770+
def get_version_tag_from_version(version, cached_tags):
    """
    Add the ``version-*`` cache tag(s) for ``version`` to ``cached_tags``.

    ``cached_tags`` is modified in place; tags already present are not
    duplicated.  For a version string containing ``-git-`` two tags are
    ensured: one for the full version string and one for the part that
    precedes the first ``-git-``.

    Args:
        version (str): version string; an empty string means "no version"
            and leaves ``cached_tags`` untouched.
        cached_tags (list): list of tag strings to extend in place.

    Returns:
        dict: always ``{'return': 0}``.
    """
    if version == '':
        return {'return': 0}

    # The full version tag always comes first, followed (for git-style
    # versions) by the tag without the "-git-..." suffix.
    wanted_tags = ['version-' + version]
    if '-git-' in version:
        wanted_tags.append('version-' + version.split("-git-")[0])

    for tag in wanted_tags:
        if tag not in cached_tags:
            cached_tags.append(tag)

    return {'return': 0}
4780+
47604781
##############################################################################
4782+
4783+
47614784
def find_cached_script(i):
47624785
"""
47634786
Internal automation function: find cached script
@@ -4867,11 +4890,12 @@ def find_cached_script(i):
48674890
recursion_spaces +
48684891
' - Prepared variations: {}'.format(variation_tags_string))
48694892

4870-
# Add version
4871-
if version != '':
4872-
if 'version-' + version not in cached_tags:
4873-
cached_tags.append('version-' + version)
4874-
explicit_cached_tags.append('version-' + version)
4893+
r = get_version_tag_from_version(version, cached_tags)
4894+
if r['return'] > 0:
4895+
return r
4896+
get_version_tag_from_version(version, explicit_cached_tags)
4897+
if r['return'] > 0:
4898+
return r
48754899

48764900
# Add extra cache tags (such as "virtual" for python)
48774901
if len(extra_cache_tags) > 0:
@@ -4905,9 +4929,9 @@ def find_cached_script(i):
49054929
recursion_spaces +
49064930
' - Searching for cached script outputs with the following tags: {}'.format(search_tags))
49074931

4908-
r = self_obj.action_object.access({'action': 'search',
4909-
'target_name': 'cache',
4910-
'tags': search_tags})
4932+
r = self_obj.cache_action.access({'action': 'search',
4933+
'target_name': 'cache',
4934+
'tags': search_tags})
49114935
if r['return'] > 0:
49124936
return r
49134937

@@ -4986,21 +5010,6 @@ def find_cached_script(i):
49865010
if r['return'] > 0:
49875011
return r
49885012

4989-
# Check if pre-process and detect
4990-
# if 'preprocess' in dir(customize_code):
4991-
4992-
# logger.debug(recursion_spaces + ' - Running preprocess ...')
4993-
4994-
# ii = copy.deepcopy(customize_common_input)
4995-
# ii['env'] = env
4996-
# ii['meta'] = meta
4997-
# # may need to detect versions in multiple paths
4998-
# ii['run_script_input'] = run_script_input
4999-
5000-
# r = customize_code.preprocess(ii)
5001-
# if r['return'] > 0:
5002-
# return r
5003-
50045013
ii = {
50055014
'run_script_input': run_script_input,
50065015
'env': env,

script/clean-nvidia-mlperf-inference-scratch-space/customize.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ def preprocess(i):
3030
cache_rm_tags = "nvidia-harness,_download_model,_sdxl"
3131

3232
cache_rm_tags = cache_rm_tags + extra_cache_rm_tags
33-
mlc = i['automation'].action_object
33+
mlc_cache = i['automation'].cache_action
3434

3535
if cache_rm_tags:
36-
r = mlc.access({'action': 'rm', 'automation': 'cache',
37-
'tags': cache_rm_tags, 'f': True})
36+
r = mlc_cache.access({'action': 'rm', 'target': 'cache',
37+
'tags': cache_rm_tags, 'f': True})
3838
print(r)
3939
if r['return'] != 0 and r['return'] != 16: # ignore missing ones
4040
return r

0 commit comments

Comments
 (0)