Skip to content

Commit f3acf03

Browse files
authored
[bug-fix] LLM Artifact Gateway .list_files() (#416)
* fix test to catch use case * fix parsing for prefixes
1 parent ad6b764 commit f3acf03

File tree

3 files changed

+12
-8
lines changed

3 files changed

+12
-8
lines changed

model-engine/model_engine_server/core/utils/url.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class InvalidAttachmentUrl(ValueError):
3232
pass
3333

3434

35-
def parse_attachment_url(url: str) -> ParsedURL:
35+
def parse_attachment_url(url: str, clean_key: bool = True) -> ParsedURL:
3636
"""Extracts protocol, bucket, region, and key from the :param:`url`.
3737
3838
:raises: InvalidAttachmentUrl Iff the input `url` is not a valid AWS S3 or GCS url.
@@ -102,5 +102,5 @@ def clean(v):
102102
protocol=clean(protocol),
103103
bucket=clean(bucket),
104104
region=clean(region),
105-
key=clean(key),
105+
key=clean(key) if clean_key else key,
106106
)

model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def _get_s3_resource(self, kwargs):
2323

2424
def list_files(self, path: str, **kwargs) -> List[str]:
2525
s3 = self._get_s3_resource(kwargs)
26-
parsed_remote = parse_attachment_url(path)
26+
parsed_remote = parse_attachment_url(path, clean_key=False)
2727
bucket = parsed_remote.bucket
2828
key = parsed_remote.key
2929

@@ -33,7 +33,7 @@ def list_files(self, path: str, **kwargs) -> List[str]:
3333

3434
def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) -> List[str]:
3535
s3 = self._get_s3_resource(kwargs)
36-
parsed_remote = parse_attachment_url(path)
36+
parsed_remote = parse_attachment_url(path, clean_key=False)
3737
bucket = parsed_remote.bucket
3838
key = parsed_remote.key
3939

@@ -58,7 +58,9 @@ def download_files(self, path: str, target_path: str, overwrite=False, **kwargs)
5858

5959
def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]:
6060
s3 = self._get_s3_resource(kwargs)
61-
parsed_remote = parse_attachment_url(hmi_config.hf_user_fine_tuned_weights_prefix)
61+
parsed_remote = parse_attachment_url(
62+
hmi_config.hf_user_fine_tuned_weights_prefix, clean_key=False
63+
)
6264
bucket = parsed_remote.bucket
6365
fine_tuned_weights_prefix = parsed_remote.key
6466

model-engine/tests/unit/infra/gateways/test_s3_llm_artifact_gateway.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def llm_artifact_gateway():
1414

1515
@pytest.fixture
1616
def fake_files():
17-
return ["fake-prefix/fake1", "fake-prefix/fake2", "fake-prefix/fake3"]
17+
return ["fake-prefix/fake1", "fake-prefix/fake2", "fake-prefix/fake3", "fake-prefix-ext/fake1"]
1818

1919

2020
def mock_boto3_session(fake_files: List[str]):
@@ -39,11 +39,13 @@ def filter_files(*args, **kwargs):
3939
lambda *args, **kwargs: None, # noqa
4040
)
4141
def test_s3_llm_artifact_gateway_download_folder(llm_artifact_gateway, fake_files):
42-
prefix = "/".join(fake_files[0].split("/")[:-1])
42+
prefix = "/".join(fake_files[0].split("/")[:-1]) + "/"
4343
uri_prefix = f"s3://fake-bucket/{prefix}"
4444
target_dir = "fake-target"
4545

46-
expected_files = [f"{target_dir}/{file.split('/')[-1]}" for file in fake_files]
46+
expected_files = [
47+
f"{target_dir}/{file.split('/')[-1]}" for file in fake_files if file.startswith(prefix)
48+
]
4749
with mock.patch(
4850
"model_engine_server.infra.gateways.s3_llm_artifact_gateway.boto3.Session",
4951
mock_boto3_session(fake_files),

0 commit comments

Comments
 (0)