diff --git a/servicex_codegen/code_generator.py b/servicex_codegen/code_generator.py index 5de07cc..753065c 100644 --- a/servicex_codegen/code_generator.py +++ b/servicex_codegen/code_generator.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 , IRIS-HEP +# Copyright (c) 2022-2025, IRIS-HEP # All rights reserved. # # Redistribution and use in source and binary forms, with or without @@ -32,9 +32,15 @@ # modification, are permitted provided that the following conditions are met: # from abc import ABC, abstractmethod -from collections import namedtuple +from dataclasses import dataclass +from typing import Optional -GeneratedFileResult = namedtuple('GeneratedFileResult', 'hash output_dir') + +@dataclass +class GeneratedFileResult: + hash: str + output_dir: str + image: Optional[str] = None class GenerateCodeException(BaseException): @@ -47,5 +53,5 @@ def __init__(self, message: str): class CodeGenerator(ABC): @abstractmethod - def generate_code(self, query, cache_path: str): + def generate_code(self, query, cache_path: str) -> GeneratedFileResult: pass diff --git a/servicex_codegen/post_operation.py b/servicex_codegen/post_operation.py index 0d300dd..af9f399 100644 --- a/servicex_codegen/post_operation.py +++ b/servicex_codegen/post_operation.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022 , IRIS-HEP +# Copyright (c) 2022-2025, IRIS-HEP # All rights reserved. # # Redistribution and use in source and binary forms, with or without @@ -97,7 +97,9 @@ def post(self): zip_data = self.stream_generated_code(generated_code_result) # code gen transformer returns the default transformer image mentioned in # the config file - transformer_image = current_app.config['TRANSFORMER_SCIENCE_IMAGE'] + transformer_image = (generated_code_result.image + if generated_code_result.image is not None + else current_app.config['TRANSFORMER_SCIENCE_IMAGE']) # MultipartEncoder library takes multiple types of data fields and merge # them into a multipart mime data type diff --git a/tests/test_post_operation.py b/tests/test_post_operation.py index 6869184..e413e2f 100644 --- a/tests/test_post_operation.py +++ b/tests/test_post_operation.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019, IRIS-HEP +# Copyright (c) 2019-2025, IRIS-HEP # All rights reserved. # # Redistribution and use in source and binary forms, with or without @@ -76,6 +76,7 @@ def test_post_good_query_with_params(self, mocker): select_stmt = "(call ResultTTree (call Select (call SelectMany (call EventDataset (list 'localds://did_01')" # noqa: E501 + # Note we will ignore any image sent in the request! response = client.post("/servicex/generated-code", json={ "transformer_image": "sslhep/servicex_func_adl_xaod_transformer:develop", "code": select_stmt @@ -92,6 +93,7 @@ def test_post_good_query_with_params(self, mocker): print("Zip File: ", zip_file) assert response.status_code == 200 + assert transformer_image == 'foo/bar:latest' check_zip_file(zip_file, 2) # Capture the temporary directory that was generated cache_dir = mock_ast_translator.generate_code.call_args[1]['cache_path'] @@ -137,6 +139,53 @@ def test_post_good_query_without_params(self, mocker): print("Zip File: ", zip_file) assert response.status_code == 200 + assert transformer_image == 'sslhep/servicex_func_adl_xaod_transformer:develop' + check_zip_file(zip_file, 2) + # Capture the temporary directory that was generated + cache_dir = mock_ast_translator.generate_code.call_args[1]['cache_path'] + mock_ast_translator.generate_code.assert_called_with(select_stmt, + cache_path=cache_dir) + + def test_post_good_query_with_params_and_image(self, mocker): + """Produce code for a simple good query""" + + with TemporaryDirectory() as tempdir, \ + open(os.path.join(tempdir, "baz.txt"), 'w'), \ + open(os.path.join(tempdir, "foo.txt"), 'w'): + + mock_ast_translator = mocker.Mock() + mock_ast_translator.generate_code = mocker.Mock( + return_value=GeneratedFileResult(hash="1234", output_dir=tempdir, + image='sslhep/servicex_func_adl_xaod_transformer:develop') + ) + + config = { + 'TARGET_BACKEND': 'uproot', + 'TRANSFORMER_SCIENCE_IMAGE': "foo/bar:latest" + } + app = create_app(config, provided_translator=mock_ast_translator) + client = app.test_client() + + print(app.config) + + select_stmt = "(call ResultTTree (call Select (call SelectMany (call EventDataset (list 'localds://did_01')" # noqa: E501 + + response = client.post("/servicex/generated-code", json={ + "code": select_stmt + }) + + boundary = response.data[2:34].decode('utf-8') + content_type = f"multipart/form-data; boundary={boundary}" + decoder_parts = decoder.MultipartDecoder(response.data, content_type) + + transformer_image = str(decoder_parts.parts[0].content, 'utf-8') + zip_file = decoder_parts.parts[3].content + + print("Transformer Image: ", transformer_image) + print("Zip File: ", zip_file) + + assert response.status_code == 200 + assert transformer_image == 'sslhep/servicex_func_adl_xaod_transformer:develop' check_zip_file(zip_file, 2) # Capture the temporary directory that was generated cache_dir = mock_ast_translator.generate_code.call_args[1]['cache_path']