diff --git a/examples/models/llama3_2_vision/preprocess/export_preprocess.py b/examples/models/llama3_2_vision/preprocess/export_preprocess.py
index 58c79095074..d82f79c2f35 100644
--- a/examples/models/llama3_2_vision/preprocess/export_preprocess.py
+++ b/examples/models/llama3_2_vision/preprocess/export_preprocess.py
@@ -5,28 +5,47 @@
 # LICENSE file in the root directory of this source tree.
 
 import torch
-from executorch.examples.models.llama3_2_vision.preprocess.export_preprocess_lib import (
-    export_preprocess,
-    get_example_inputs,
-    lower_to_executorch_preprocess,
+from executorch.examples.models.llama3_2_vision.preprocess.model import (
+    CLIPImageTransformModel,
+    PreprocessConfig,
 )
+from executorch.exir import EdgeCompileConfig, to_edge
 
 
 def main():
+    # Eager model.
+    model = CLIPImageTransformModel(PreprocessConfig())
 
-    # ExecuTorch
-    ep_et = export_preprocess()
-    et = lower_to_executorch_preprocess(ep_et)
-    with open("preprocess_et.pte", "wb") as file:
-        et.write_to_file(file)
-
-    # AOTInductor
-    ep_aoti = export_preprocess()
-    torch._inductor.aot_compile(
-        ep_aoti.module(),
-        get_example_inputs(),
-        options={"aot_inductor.output_path": "preprocess_aoti.so"},
+    # Export.
+    ep = torch.export.export(
+        model.get_eager_model(),
+        model.get_example_inputs(),
+        dynamic_shapes=model.get_dynamic_shapes(),
+        strict=False,
+    )
+
+    # Executorch
+    edge_program = to_edge(
+        ep, compile_config=EdgeCompileConfig(_check_ir_validity=False)
     )
+    et_program = edge_program.to_executorch()
+    with open("preprocess_et.pte", "wb") as file:
+        et_program.write_to_file(file)
+
+    # Export.
+    # ep = torch.export.export(
+    #     model.get_eager_model(),
+    #     model.get_example_inputs(),
+    #     dynamic_shapes=model.get_dynamic_shapes(),
+    #     strict=False,
+    # )
+    #
+    # # AOTInductor
+    # torch._inductor.aot_compile(
+    #     ep.module(),
+    #     model.get_example_inputs(),
+    #     options={"aot_inductor.output_path": "preprocess_aoti.so"},
+    # )
 
 
 if __name__ == "__main__":
diff --git a/examples/models/llama3_2_vision/preprocess/export_preprocess_lib.py b/examples/models/llama3_2_vision/preprocess/export_preprocess_lib.py
deleted file mode 100644
index f3fe8188c04..00000000000
--- a/examples/models/llama3_2_vision/preprocess/export_preprocess_lib.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-from typing import Dict, List, Optional, Tuple
-
-import torch
-from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge
-from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
-from executorch.exir.program._program import ExecutorchProgramManager
-
-from executorch.extension.llm.custom_ops import op_tile_crop_aot  # noqa
-
-from torch.export import Dim, ExportedProgram
-from torchtune.models.clip.inference._transform import _CLIPImageTransform
-
-
-def get_example_inputs() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    image = torch.ones(3, 800, 600)
-    target_size = torch.tensor([448, 336])
-    canvas_size = torch.tensor([448, 448])
-    return (image, target_size, canvas_size)
-
-
-def get_dynamic_shapes() -> Dict[str, Dict[int, Dim]]:
-    img_h = Dim("img_h", min=1, max=4000)
-    img_w = Dim("img_w", min=1, max=4000)
-
-    dynamic_shapes = {
-        "image": {1: img_h, 2: img_w},
-        "target_size": None,
-        "canvas_size": None,
-    }
-    return dynamic_shapes
-
-
-def export_preprocess(
-    resample: str = "bilinear",
-    image_mean: Optional[List[float]] = None,
-    image_std: Optional[List[float]] = None,
-    max_num_tiles: int = 4,
-    tile_size: int = 224,
-    antialias: bool = False,
-) -> ExportedProgram:
-
-    # Instantiate eager model.
-    image_transform_model = _CLIPImageTransform(
-        resample=resample,
-        image_mean=image_mean,
-        image_std=image_std,
-        max_num_tiles=max_num_tiles,
-        tile_size=tile_size,
-        antialias=antialias,
-    )
-
-    # Replace non-exportable ops with custom ops.
-    image_transform_model.tile_crop = torch.ops.preprocess.tile_crop.default
-
-    # Export.
-    example_inputs = get_example_inputs()
-    dynamic_shapes = get_dynamic_shapes()
-    ep = torch.export.export(
-        image_transform_model,
-        example_inputs,
-        dynamic_shapes=dynamic_shapes,
-        strict=False,
-    )
-    return ep
-
-
-def lower_to_executorch_preprocess(
-    exported_program: ExportedProgram,
-) -> ExecutorchProgramManager:
-    edge_program = to_edge(
-        exported_program, compile_config=EdgeCompileConfig(_check_ir_validity=False)
-    )
-
-    et_program = edge_program.to_executorch(
-        ExecutorchBackendConfig(
-            sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
-        )
-    )
-    return et_program
diff --git a/examples/models/llama3_2_vision/preprocess/model.py b/examples/models/llama3_2_vision/preprocess/model.py
new file mode 100644
index 00000000000..7b3b4869af6
--- /dev/null
+++ b/examples/models/llama3_2_vision/preprocess/model.py
@@ -0,0 +1,69 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple
+
+import torch
+
+from executorch.extension.llm.custom_ops import op_tile_crop_aot  # noqa
+from torch.export import Dim
+from torchtune.models.clip.inference._transform import _CLIPImageTransform
+
+from ...model_base import EagerModelBase
+
+
+@dataclass
+class PreprocessConfig:
+    image_mean: Optional[List[float]] = None
+    image_std: Optional[List[float]] = None
+    resample: str = "bilinear"
+    max_num_tiles: int = 4
+    tile_size: int = 224
+    antialias: bool = False
+
+
+class CLIPImageTransformModel(EagerModelBase):
+    def __init__(
+        self,
+        config: PreprocessConfig,
+    ):
+        super().__init__()
+
+        # Eager model.
+        self.model = _CLIPImageTransform(
+            image_mean=config.image_mean,
+            image_std=config.image_std,
+            resample=config.resample,
+            max_num_tiles=config.max_num_tiles,
+            tile_size=config.tile_size,
+            antialias=config.antialias,
+        )
+
+        # Replace non-exportable ops with custom ops.
+        self.model.tile_crop = torch.ops.preprocess.tile_crop.default
+
+    def get_eager_model(self) -> torch.nn.Module:
+        return self.model
+
+    def get_example_inputs(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        image = torch.ones(3, 800, 600)
+        target_size = torch.tensor([448, 336])
+        canvas_size = torch.tensor([448, 448])
+        return (image, target_size, canvas_size)
+
+    def get_dynamic_shapes(self) -> Dict[str, Dict[int, Dim]]:
+        img_h = Dim("img_h", min=1, max=4000)
+        img_w = Dim("img_w", min=1, max=4000)
+
+        dynamic_shapes = {
+            "image": {1: img_h, 2: img_w},
+            "target_size": None,
+            "canvas_size": None,
+        }
+        return dynamic_shapes
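A minimal smoke-test sketch (not part of the patch) for this refactor: it exports the new `CLIPImageTransformModel` the same way `export_preprocess.py` does and checks the exported graph against the eager `_CLIPImageTransform` on the example inputs. It assumes the `op_tile_crop_aot` extension imported by `model.py` is built and importable, so the `preprocess.tile_crop` custom op is registered for eager execution.

```python
# Sketch: verify that the torch.export graph matches the eager model.
# Assumes the op_tile_crop_aot library is available so preprocess.tile_crop
# is registered; otherwise constructing the model below will fail.
import torch

from executorch.examples.models.llama3_2_vision.preprocess.model import (
    CLIPImageTransformModel,
    PreprocessConfig,
)

model = CLIPImageTransformModel(PreprocessConfig())
example_inputs = model.get_example_inputs()

# Same export call as export_preprocess.py.
ep = torch.export.export(
    model.get_eager_model(),
    example_inputs,
    dynamic_shapes=model.get_dynamic_shapes(),
    strict=False,
)

eager_out = model.get_eager_model()(*example_inputs)
exported_out = ep.module()(*example_inputs)

# assert_close accepts tensors as well as (nested) sequences of tensors.
torch.testing.assert_close(exported_out, eager_out)
print("exported preprocess matches eager output")
```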