From eeb76482cde2a7caf364083ab61d2a26b727f166 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Mon, 3 Mar 2025 16:45:09 -0800 Subject: [PATCH 01/11] Remove libOMP linking for experimental kernels --- torchao/experimental/Utils.cmake | 1 - 1 file changed, 1 deletion(-) diff --git a/torchao/experimental/Utils.cmake b/torchao/experimental/Utils.cmake index 984c90006b..57a35e61fa 100644 --- a/torchao/experimental/Utils.cmake +++ b/torchao/experimental/Utils.cmake @@ -21,7 +21,6 @@ function(target_link_torchao_parallel_backend target_name torchao_parallel_backe target_link_libraries(${target_name} PRIVATE "${TORCH_LIBRARIES}") target_compile_definitions(${target_name} PRIVATE TORCHAO_PARALLEL_ATEN=1 AT_PARALLEL_OPENMP=1 INTRA_OP_PARALLEL=1) - target_link_libraries(${target_name} PRIVATE ${TORCH_INSTALL_PREFIX}/lib/libomp${CMAKE_SHARED_LIBRARY_SUFFIX}) elseif(TORCHAO_PARALLEL_BACKEND_TOUPPER STREQUAL "EXECUTORCH") message(STATUS "Building with TORCHAO_PARALLEL_BACKEND=TORCHAO_PARALLEL_EXECUTORCH") From 6bf760eca5d6d7948a43f62437fa83278a63d3fb Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Tue, 4 Mar 2025 10:23:01 -0800 Subject: [PATCH 02/11] Install nightly instead of pinned version --- .github/workflows/torchao_experimental_test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/torchao_experimental_test.yml b/.github/workflows/torchao_experimental_test.yml index e1511ffe9a..fab898c074 100644 --- a/.github/workflows/torchao_experimental_test.yml +++ b/.github/workflows/torchao_experimental_test.yml @@ -33,7 +33,7 @@ jobs: - name: Install requirements run: | conda activate venv - pip install --extra-index-url "https://download.pytorch.org/whl/nightly/cpu" torch=="2.6.0.dev20250104" + pip install --extra-index-url "https://download.pytorch.org/whl/nightly/cpu" pip install numpy pip install pytest USE_CPP=1 pip install . From e122fd55489f5adf1e907ad5d1b71aa568706109 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Tue, 4 Mar 2025 10:48:26 -0800 Subject: [PATCH 03/11] Update torchao_experimental_test.yml --- .github/workflows/torchao_experimental_test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/torchao_experimental_test.yml b/.github/workflows/torchao_experimental_test.yml index fab898c074..0646eebc03 100644 --- a/.github/workflows/torchao_experimental_test.yml +++ b/.github/workflows/torchao_experimental_test.yml @@ -33,7 +33,7 @@ jobs: - name: Install requirements run: | conda activate venv - pip install --extra-index-url "https://download.pytorch.org/whl/nightly/cpu" + pip install torch --extra-index-url "https://download.pytorch.org/whl/nightly/cpu" pip install numpy pip install pytest USE_CPP=1 pip install .
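Editor's note (not part of the patch series): after PATCH 01 drops the explicit libomp link, the TORCHAO_PARALLEL_ATEN backend presumably relies on the OpenMP runtime already pulled in via `${TORCH_LIBRARIES}`. A minimal sketch for sanity-checking that assumption at runtime with standard PyTorch introspection APIs (the thread count of 8 is purely illustrative):

```python
import torch

# Print ATen's parallelism configuration; with TORCHAO_PARALLEL_ATEN=1 and
# AT_PARALLEL_OPENMP=1 this should report an OpenMP-backed intra-op thread pool.
print(torch.__config__.parallel_info())

# True only if this torch build was compiled with OpenMP support.
print(torch.backends.openmp.is_available())

# Intra-op thread count used by the ATen backend (8 is an arbitrary example value).
torch.set_num_threads(8)
print(torch.get_num_threads())
```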
From a48dbe3f20e1e91f48ce9d20d7d685383e479fed Mon Sep 17 00:00:00 2001 From: ngc92 <7938269+ngc92@users.noreply.github.com> Date: Tue, 4 Mar 2025 02:06:06 +0100 Subject: [PATCH 04/11] CPUOffload: only offload parameters above a certain size (#1720) * CPUOffload: only offload parameters above a certain size * lint * ruff --------- Co-authored-by: Mark Saroufim --- test/prototype/test_low_bit_optim.py | 12 +++-- .../prototype/low_bit_optim/cpu_offload.py | 48 +++++++++++++++++-- 2 files changed, 53 insertions(+), 7 deletions(-) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index 453210abda..deaead873b 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -273,11 +273,11 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum): model1 = nn.Sequential( nn.Linear(32, 131072), nn.ReLU(), - nn.Linear(131072, 64), + nn.Linear(131072, 64, bias=True), nn.ReLU(), - nn.Linear(64, 64), + nn.Linear(64, 64, bias=True), nn.ReLU(), - nn.Linear(64, 128), + nn.Linear(64, 128, bias=True), ) model1.to(device) @@ -329,7 +329,11 @@ def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum): ) def test_optim_cpu_offload_save_load(self): device = _DEVICES[-1] - model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)) + # enable bias parameters so we have some small tensors that + # are not offloaded. + model1 = nn.Sequential( + nn.Linear(32, 1024, bias=True), nn.ReLU(), nn.Linear(1024, 128, bias=True) + ) model1.to(device) optim1 = low_bit_optim.CPUOffloadOptimizer( model1.parameters(), torch.optim.AdamW diff --git a/torchao/prototype/low_bit_optim/cpu_offload.py b/torchao/prototype/low_bit_optim/cpu_offload.py index b94340a32a..61e4077d1d 100644 --- a/torchao/prototype/low_bit_optim/cpu_offload.py +++ b/torchao/prototype/low_bit_optim/cpu_offload.py @@ -17,6 +17,7 @@ def __init__( optimizer_class: Type[Optimizer] = torch.optim.AdamW, *, offload_gradients: bool = False, + minimal_size: int = 4096, **kwargs, ) -> None: """Offload optimizer to CPU for single-GPU training. This will reduce GPU memory by the size of optimizer state. @@ -26,6 +27,7 @@ def __init__( params: a list of parameters or parameter groups. optimizer_class: constructor of the base optimizer. Defaults to :class:`torch.optim.AdamW`. offload_gradients: free GPU gradients once they are moved to CPU. Not compatible with gradient accumulation. + minimal_size: tensors smaller than this are kept on the GPU, to avoid excessively many small transfers. kwargs: other keyword arguments to be passed to the base optimizer e.g. `lr`, `weight_decay`. 
""" # default to fused CPU AdamW @@ -42,6 +44,11 @@ def __init__( if not isinstance(param_groups[0], dict): param_groups = [{"params": param_groups}] + # any parameter smaller than minimal size will be handled by the on-device optimizer d_opt + self.minimal_size = minimal_size + self.d_opt = None + self.d_param_groups = [] + self.param_d2h_map = dict() self.optim_dict = dict() self.device = get_available_devices()[-1] @@ -77,11 +84,16 @@ def backward_hook(p_device): for param_group in param_groups: params = param_group.pop("params") + retained_params = [] for p_device in params: if not p_device.requires_grad: continue + if p_device.numel() < self.minimal_size: + retained_params.append(p_device) + continue + # pre-allocate CPU params and grads p_host = torch.empty_like(p_device, device="cpu", pin_memory=True) p_host.grad = torch.empty_like(p_host, pin_memory=True) @@ -94,12 +106,22 @@ def backward_hook(p_device): [{"params": p_host, **param_group}], **kwargs ) + if len(retained_params) > 0: + self.d_param_groups.append({"params": retained_params, **param_group}) + + if len(self.d_param_groups) > 0: + self.d_opt = optimizer_class(self.d_param_groups, **kwargs) + @torch.no_grad() def step(self, closure=None): loss = None if closure is not None: loss = closure() + # handle small parameters on the GPU, in parallel with the CPU calls below + if self.d_opt is not None: + self.d_opt.step() + for p_device, grad_d2h_event in self.queue.items(): grad_d2h_event.synchronize() self.optim_dict[p_device].step() @@ -123,15 +145,35 @@ def zero_grad(self, set_to_none=True): for p_device in self.param_d2h_map.keys(): p_device.grad = None + if self.d_opt is not None: + self.d_opt.zero_grad(set_to_none=set_to_none) + @property def param_groups(self): # each param group will only has 1 parameter # TODO: we might want to return the original param_groups instead. 
- return sum((optim.param_groups for optim in self.optim_dict.values()), start=[]) + return sum( + (optim.param_groups for optim in self.optim_dict.values()), + start=self.d_param_groups, + ) def state_dict(self): - return [optim.state_dict() for optim in self.optim_dict.values()] + state_dict = { + "offloaded": [optim.state_dict() for optim in self.optim_dict.values()] + } + if self.d_opt: + state_dict["on-device"] = self.d_opt.state_dict() + return state_dict def load_state_dict(self, state_dict): - for optim, optim_state_dict in zip(self.optim_dict.values(), state_dict): + for optim, optim_state_dict in zip( + self.optim_dict.values(), state_dict["offloaded"] + ): optim.load_state_dict(optim_state_dict) + + if self.d_opt: + self.d_opt.load_state_dict(state_dict["on-device"]) + elif "on-device" in state_dict: + raise ValueError( + "loaded state dict has a 'on-device' parameter group not present in the optimizer" + ) From 6e988bbf11809bb32baa0632704882cbce10095f Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Tue, 4 Mar 2025 12:59:22 +0900 Subject: [PATCH 05/11] update typehint (#1740) * update typehint Signed-off-by: Masaki Kozuki * Update float8_linear_utils.py --------- Signed-off-by: Masaki Kozuki Co-authored-by: Mark Saroufim --- torchao/float8/float8_linear_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/float8/float8_linear_utils.py b/torchao/float8/float8_linear_utils.py index db9889567f..8ea6e2e23a 100644 --- a/torchao/float8/float8_linear_utils.py +++ b/torchao/float8/float8_linear_utils.py @@ -85,7 +85,7 @@ def convert_to_float8_training( module: nn.Module, *, module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None, - config: Float8LinearConfig = None, + config: Optional[Float8LinearConfig] = None, ) -> nn.Module: """ Swaps `torch.nn.Linear` in `module` with `Float8Linear`. 
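Editor's note (an illustration, not part of the patches): the test changes in PATCH 04 exercise the new `minimal_size` argument roughly as sketched below. The import path mirrors the test file, while the model shape, learning rate, and use of CUDA are assumptions for the example; parameters with fewer than `minimal_size` elements (the biases here) stay on the GPU and are stepped by a regular on-device optimizer, while the large weights get pinned CPU copies managed by per-parameter base optimizers.

```python
import torch
import torch.nn as nn
from torchao.prototype import low_bit_optim  # import path assumed from test/prototype/test_low_bit_optim.py

model = nn.Sequential(
    nn.Linear(32, 1024, bias=True),   # 32*1024 elements -> offloaded to pinned CPU memory
    nn.ReLU(),
    nn.Linear(1024, 128, bias=True),  # biases (1024 and 128 elements) -> kept on the GPU
).cuda()

optim = low_bit_optim.CPUOffloadOptimizer(
    model.parameters(),
    torch.optim.AdamW,        # base optimizer class; its state lives alongside the CPU copies
    offload_gradients=False,  # keep GPU gradients so gradient accumulation still works
    minimal_size=4096,        # new in this patch: params with fewer elements stay on-device
    lr=1e-3,                  # extra kwargs are forwarded to the base optimizer
)

x = torch.randn(8, 32, device="cuda")
model(x).sum().backward()
optim.step()       # small params step on the GPU while the offloaded CPU copies are updated
optim.zero_grad()
```

The motivation, per the docstring added in the patch, is to avoid paying a host-device transfer for every tiny tensor: offloading only tensors above `minimal_size` limits the number of small copies while still moving the bulk of the optimizer state off the GPU.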
From 51e359a32cdc317360b217385399906df8a258e6 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Mon, 3 Mar 2025 22:07:53 -0800 Subject: [PATCH 06/11] Move torchao/_models to benchmarks/_models (#1784) --- .github/workflows/dashboard_perf_test.yml | 10 +- README.md | 4 +- {torchao/_models => benchmarks}/__init__.py | 0 {torchao => benchmarks}/_models/README.md | 0 .../llama => benchmarks/_models}/__init__.py | 0 {torchao => benchmarks}/_models/_eval.py | 0 .../_models/llama/.gitignore | 0 .../_models/llama/README.md | 2 +- benchmarks/_models/llama/__init__.py | 0 .../_models/llama/benchmark_results.txt | 0 .../_models/llama/benchmarks.sh | 0 .../_models/llama/demo_summarize.sh | 0 {torchao => benchmarks}/_models/llama/eval.py | 21 ++- .../_models/llama/evals.sh | 0 .../_models/llama/generate.py | 134 +++------------ .../_models/llama/model.py | 0 .../_models/llama/perf_profile.py | 4 +- .../_models/llama/tokenizer.py | 0 .../_models/sam/.gitignore | 0 {torchao => benchmarks}/_models/sam/README.md | 0 benchmarks/_models/sam/__init__.py | 0 .../_models/sam/benchmark.sh | 0 {torchao => benchmarks}/_models/sam/data.py | 0 .../_models/sam/eval_combo.py | 2 +- .../_models/sam/flash_4_configs.p | Bin .../_models/sam/metrics.py | 0 .../_models/sam/results.csv | 0 {torchao => benchmarks}/_models/sam/setup.sh | 0 .../_models/sam2/__init__.py | 2 +- .../_models/sam2/automatic_mask_generator.py | 8 +- .../_models/sam2/build_sam.py | 4 +- .../sam2/configs/sam2.1/sam2.1_hiera_b+.yaml | 28 ++-- .../sam2/configs/sam2.1/sam2.1_hiera_l.yaml | 28 ++-- .../sam2/configs/sam2.1/sam2.1_hiera_s.yaml | 28 ++-- .../sam2/configs/sam2.1/sam2.1_hiera_t.yaml | 28 ++-- .../sam2.1_hiera_b+_MOSE_finetune.yaml | 0 .../sam2/configs/sam2/sam2_hiera_b+.yaml | 28 ++-- .../sam2/configs/sam2/sam2_hiera_l.yaml | 28 ++-- .../sam2/configs/sam2/sam2_hiera_s.yaml | 28 ++-- .../sam2/configs/sam2/sam2_hiera_t.yaml | 28 ++-- .../_models/sam2/csrc/connected_components.cu | 0 .../_models/sam2/map_tensor.py | 0 .../_models/sam2/modeling/__init__.py | 0 .../sam2/modeling/backbones/__init__.py | 0 .../sam2/modeling/backbones/hieradet.py | 4 +- .../sam2/modeling/backbones/image_encoder.py | 2 +- .../_models/sam2/modeling/backbones/utils.py | 0 .../_models/sam2/modeling/memory_attention.py | 4 +- .../_models/sam2/modeling/memory_encoder.py | 6 +- .../sam2/modeling/position_encoding.py | 0 .../_models/sam2/modeling/sam/__init__.py | 0 .../_models/sam2/modeling/sam/mask_decoder.py | 2 +- .../sam2/modeling/sam/prompt_encoder.py | 4 +- .../_models/sam2/modeling/sam/transformer.py | 6 +- .../_models/sam2/modeling/sam2_base.py | 8 +- .../_models/sam2/modeling/sam2_utils.py | 2 +- .../_models/sam2/sam2_hiera_b+.yaml | 0 .../_models/sam2/sam2_hiera_l.yaml | 0 .../_models/sam2/sam2_hiera_s.yaml | 0 .../_models/sam2/sam2_hiera_t.yaml | 0 .../_models/sam2/sam2_image_predictor.py | 6 +- .../_models/sam2/sam2_video_predictor.py | 6 +- .../_models/sam2/utils/__init__.py | 0 .../_models/sam2/utils/amg.py | 0 .../_models/sam2/utils/misc.py | 0 .../_models/sam2/utils/transforms.py | 4 +- {torchao => benchmarks}/_models/utils.py | 89 ++++++++++ .../quantized_training/pretrain_llama2.py | 4 +- docs/source/contributor_guide.rst | 10 +- examples/sam2_amg_server/annotate_with_rle.py | 2 +- examples/sam2_amg_server/cli.py | 6 +- examples/sam2_amg_server/cli_on_modal.py | 8 +- examples/sam2_amg_server/compare_rle_lists.py | 2 +- .../sam2_amg_server/compile_export_utils.py | 12 +- examples/sam2_amg_server/generate_data.py | 10 +- 
.../sam2_amg_server/result_batch_size_16.csv | 154 +++++++++--------- examples/sam2_amg_server/server.py | 8 +- .../sam2_vos_example/compile_export_utils.py | 2 +- examples/sam2_vos_example/video_profile.py | 4 +- scripts/convert_hf_checkpoint.py | 2 +- test/prototype/test_spinquant.py | 2 +- test/quantization/test_gptq_mt.py | 4 +- test/quantization/test_quant_api.py | 16 +- test/test_ao_models.py | 2 +- torchao/prototype/awq/README.md | 8 +- .../scripts/BO_acc_throughput.py | 16 +- torchao/prototype/spinquant/spinquant.py | 2 +- torchao/quantization/GPTQ.py | 4 +- torchao/quantization/README.md | 12 +- torchao/sparsity/README.md | 2 +- torchao/utils.py | 20 +++ 91 files changed, 445 insertions(+), 425 deletions(-) rename {torchao/_models => benchmarks}/__init__.py (100%) rename {torchao => benchmarks}/_models/README.md (100%) rename {torchao/_models/llama => benchmarks/_models}/__init__.py (100%) rename {torchao => benchmarks}/_models/_eval.py (100%) rename {torchao => benchmarks}/_models/llama/.gitignore (100%) rename {torchao => benchmarks}/_models/llama/README.md (95%) create mode 100644 benchmarks/_models/llama/__init__.py rename {torchao => benchmarks}/_models/llama/benchmark_results.txt (100%) rename {torchao => benchmarks}/_models/llama/benchmarks.sh (100%) rename {torchao => benchmarks}/_models/llama/demo_summarize.sh (100%) rename {torchao => benchmarks}/_models/llama/eval.py (96%) rename {torchao => benchmarks}/_models/llama/evals.sh (100%) rename {torchao => benchmarks}/_models/llama/generate.py (91%) rename {torchao => benchmarks}/_models/llama/model.py (100%) rename {torchao => benchmarks}/_models/llama/perf_profile.py (99%) rename {torchao => benchmarks}/_models/llama/tokenizer.py (100%) rename {torchao => benchmarks}/_models/sam/.gitignore (100%) rename {torchao => benchmarks}/_models/sam/README.md (100%) create mode 100644 benchmarks/_models/sam/__init__.py rename {torchao => benchmarks}/_models/sam/benchmark.sh (100%) rename {torchao => benchmarks}/_models/sam/data.py (100%) rename {torchao => benchmarks}/_models/sam/eval_combo.py (99%) rename {torchao => benchmarks}/_models/sam/flash_4_configs.p (100%) rename {torchao => benchmarks}/_models/sam/metrics.py (100%) rename {torchao => benchmarks}/_models/sam/results.csv (100%) rename {torchao => benchmarks}/_models/sam/setup.sh (100%) rename {torchao => benchmarks}/_models/sam2/__init__.py (81%) rename {torchao => benchmarks}/_models/sam2/automatic_mask_generator.py (99%) rename {torchao => benchmarks}/_models/sam2/build_sam.py (97%) rename {torchao => benchmarks}/_models/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml (71%) rename {torchao => benchmarks}/_models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml (72%) rename {torchao => benchmarks}/_models/sam2/configs/sam2.1/sam2.1_hiera_s.yaml (72%) rename {torchao => benchmarks}/_models/sam2/configs/sam2.1/sam2.1_hiera_t.yaml (72%) rename {torchao => benchmarks}/_models/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml (100%) rename {torchao => benchmarks}/_models/sam2/configs/sam2/sam2_hiera_b+.yaml (70%) rename {torchao => benchmarks}/_models/sam2/configs/sam2/sam2_hiera_l.yaml (71%) rename {torchao => benchmarks}/_models/sam2/configs/sam2/sam2_hiera_s.yaml (71%) rename {torchao => benchmarks}/_models/sam2/configs/sam2/sam2_hiera_t.yaml (72%) rename {torchao => benchmarks}/_models/sam2/csrc/connected_components.cu (100%) rename {torchao => benchmarks}/_models/sam2/map_tensor.py (100%) rename {torchao => benchmarks}/_models/sam2/modeling/__init__.py (100%) rename {torchao 
=> benchmarks}/_models/sam2/modeling/backbones/__init__.py (100%) rename {torchao => benchmarks}/_models/sam2/modeling/backbones/hieradet.py (98%) rename {torchao => benchmarks}/_models/sam2/modeling/backbones/image_encoder.py (98%) rename {torchao => benchmarks}/_models/sam2/modeling/backbones/utils.py (100%) rename {torchao => benchmarks}/_models/sam2/modeling/memory_attention.py (97%) rename {torchao => benchmarks}/_models/sam2/modeling/memory_encoder.py (98%) rename {torchao => benchmarks}/_models/sam2/modeling/position_encoding.py (100%) rename {torchao => benchmarks}/_models/sam2/modeling/sam/__init__.py (100%) rename {torchao => benchmarks}/_models/sam2/modeling/sam/mask_decoder.py (99%) rename {torchao => benchmarks}/_models/sam2/modeling/sam/prompt_encoder.py (98%) rename {torchao => benchmarks}/_models/sam2/modeling/sam/transformer.py (98%) rename {torchao => benchmarks}/_models/sam2/modeling/sam2_base.py (99%) rename {torchao => benchmarks}/_models/sam2/modeling/sam2_utils.py (99%) rename {torchao => benchmarks}/_models/sam2/sam2_hiera_b+.yaml (100%) rename {torchao => benchmarks}/_models/sam2/sam2_hiera_l.yaml (100%) rename {torchao => benchmarks}/_models/sam2/sam2_hiera_s.yaml (100%) rename {torchao => benchmarks}/_models/sam2/sam2_hiera_t.yaml (100%) rename {torchao => benchmarks}/_models/sam2/sam2_image_predictor.py (99%) rename {torchao => benchmarks}/_models/sam2/sam2_video_predictor.py (99%) rename {torchao => benchmarks}/_models/sam2/utils/__init__.py (100%) rename {torchao => benchmarks}/_models/sam2/utils/amg.py (100%) rename {torchao => benchmarks}/_models/sam2/utils/misc.py (100%) rename {torchao => benchmarks}/_models/sam2/utils/transforms.py (97%) rename {torchao => benchmarks}/_models/utils.py (54%) diff --git a/.github/workflows/dashboard_perf_test.yml b/.github/workflows/dashboard_perf_test.yml index 81ea40d341..64338aff7a 100644 --- a/.github/workflows/dashboard_perf_test.yml +++ b/.github/workflows/dashboard_perf_test.yml @@ -42,19 +42,19 @@ jobs: mkdir -p ${{ runner.temp }}/benchmark-results # llama3 - compile baseline - ${CONDA_RUN} python torchao/_models/llama/generate.py --checkpoint_path "${CHECKPOINT_PATH}/${MODEL_REPO}/model.pth" --compile --compile_prefill --output_json_path ${{ runner.temp }}/benchmark-results/llama3-benchmark-results.json + ${CONDA_RUN} python benchmarks/_models/llama/generate.py --checkpoint_path "${CHECKPOINT_PATH}/${MODEL_REPO}/model.pth" --compile --compile_prefill --output_json_path ${{ runner.temp }}/benchmark-results/llama3-benchmark-results.json # llama3 - autoquant - ${CONDA_RUN} python torchao/_models/llama/generate.py --checkpoint_path "${CHECKPOINT_PATH}/${MODEL_REPO}/model.pth" --compile --compile_prefill --quantization autoquant --output_json_path ${{ runner.temp }}/benchmark-results/llama3-benchmark-results.json + ${CONDA_RUN} python benchmarks/_models/llama/generate.py --checkpoint_path "${CHECKPOINT_PATH}/${MODEL_REPO}/model.pth" --compile --compile_prefill --quantization autoquant --output_json_path ${{ runner.temp }}/benchmark-results/llama3-benchmark-results.json # skipping SAM because of https://hud.pytorch.org/pr/pytorch/ao/1407 # # SAM # ${CONDA_RUN} pip install git+https://github.com/pytorch-labs/segment-anything-fast.git@main # # SAM compile baselilne - # ${CONDA_RUN} sh torchao/_models/sam/setup.sh - # ${CONDA_RUN} python torchao/_models/sam/eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir 
tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 8 --use_compile max-autotune --use_half bfloat16 --device cuda --output_json_path ${{ runner.temp }}/benchmark-results/sam-benchmark-results.json + # ${CONDA_RUN} sh benchmarks/_models/sam/setup.sh + # ${CONDA_RUN} python benchmarks/_models/sam/eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 8 --use_compile max-autotune --use_half bfloat16 --device cuda --output_json_path ${{ runner.temp }}/benchmark-results/sam-benchmark-results.json - # ${CONDA_RUN} python torchao/_models/sam/eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 8 --use_compile max-autotune --use_half bfloat16 --device cuda --compression autoquant --output_json_path ${{ runner.temp }}/benchmark-results/sam-benchmark-results.json + # ${CONDA_RUN} python benchmarks/_models/sam/eval_combo.py --coco_root_dir datasets/coco2017 --coco_slice_name val2017 --sam_checkpoint_base_path checkpoints --sam_model_type vit_h --point_sampling_cache_dir tmp/sam_coco_mask_center_cache --mask_debug_out_dir tmp/sam_eval_masks_out --batch_size 32 --num_workers 8 --use_compile max-autotune --use_half bfloat16 --device cuda --compression autoquant --output_json_path ${{ runner.temp }}/benchmark-results/sam-benchmark-results.json # SAM 2.1 # ${CONDA_RUN} sh scripts/download_sam2_ckpts.sh ${CHECKPOINT_PATH}/sam2 diff --git a/README.md b/README.md index 606b48986d..a48899e123 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ torchao just works with `torch.compile()` and `FSDP2` over most PyTorch models o ### Post Training Quantization -Quantizing and Sparsifying your models is a 1 liner that should work on any model with an `nn.Linear` including your favorite HuggingFace model. You can find a more comprehensive usage instructions [here](torchao/quantization/), sparsity [here](/torchao/_models/sam/README.md) and a HuggingFace inference example [here](scripts/hf_eval.py) +Quantizing and Sparsifying your models is a 1 liner that should work on any model with an `nn.Linear` including your favorite HuggingFace model. You can find a more comprehensive usage instructions [here](torchao/quantization/), sparsity [here](/benchmarks/_models/sam/README.md) and a HuggingFace inference example [here](scripts/hf_eval.py) For inference, we have the option of 1. Quantize only the weights: works best for memory bound models @@ -52,7 +52,7 @@ We also provide a developer facing API so you can implement your own quantizatio We've added kv cache quantization and other features in order to enable long context length (and necessarily memory efficient) inference. 
-In practice these features alongside int4 weight only quantization allow us to **reduce peak memory by ~55%**, meaning we can Llama3.1-8B inference with a **130k context length with only 18.9 GB of peak memory.** More details can be found [here](torchao/_models/llama/README.md) +In practice these features alongside int4 weight only quantization allow us to **reduce peak memory by ~55%**, meaning we can Llama3.1-8B inference with a **130k context length with only 18.9 GB of peak memory.** More details can be found [here](benchmarks/_models/llama/README.md) ## Training diff --git a/torchao/_models/__init__.py b/benchmarks/__init__.py similarity index 100% rename from torchao/_models/__init__.py rename to benchmarks/__init__.py diff --git a/torchao/_models/README.md b/benchmarks/_models/README.md similarity index 100% rename from torchao/_models/README.md rename to benchmarks/_models/README.md diff --git a/torchao/_models/llama/__init__.py b/benchmarks/_models/__init__.py similarity index 100% rename from torchao/_models/llama/__init__.py rename to benchmarks/_models/__init__.py diff --git a/torchao/_models/_eval.py b/benchmarks/_models/_eval.py similarity index 100% rename from torchao/_models/_eval.py rename to benchmarks/_models/_eval.py diff --git a/torchao/_models/llama/.gitignore b/benchmarks/_models/llama/.gitignore similarity index 100% rename from torchao/_models/llama/.gitignore rename to benchmarks/_models/llama/.gitignore diff --git a/torchao/_models/llama/README.md b/benchmarks/_models/llama/README.md similarity index 95% rename from torchao/_models/llama/README.md rename to benchmarks/_models/llama/README.md index 99f1919fc9..9e1bd2b062 100644 --- a/torchao/_models/llama/README.md +++ b/benchmarks/_models/llama/README.md @@ -8,7 +8,7 @@ and follow the steps to gain access. Then from the torchao root directory use `huggingface-cli login` and follow the steps to login, then `sh ./scripts/prepare.sh` to download and convert the model weights -once done you can execute benchmarks from the torchao/_models/llama dir with `sh benchmarks.sh`. You can perform and benchmarking or evaluation +once done you can execute benchmarks from the benchmarks/_models/llama dir with `sh benchmarks.sh`. You can perform and benchmarking or evaluation directly using `generate.py` or `eval.py`. 
## KV Cache Quantization - Memory Efficient Inference diff --git a/benchmarks/_models/llama/__init__.py b/benchmarks/_models/llama/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/_models/llama/benchmark_results.txt b/benchmarks/_models/llama/benchmark_results.txt similarity index 100% rename from torchao/_models/llama/benchmark_results.txt rename to benchmarks/_models/llama/benchmark_results.txt diff --git a/torchao/_models/llama/benchmarks.sh b/benchmarks/_models/llama/benchmarks.sh similarity index 100% rename from torchao/_models/llama/benchmarks.sh rename to benchmarks/_models/llama/benchmarks.sh diff --git a/torchao/_models/llama/demo_summarize.sh b/benchmarks/_models/llama/demo_summarize.sh similarity index 100% rename from torchao/_models/llama/demo_summarize.sh rename to benchmarks/_models/llama/demo_summarize.sh diff --git a/torchao/_models/llama/eval.py b/benchmarks/_models/llama/eval.py similarity index 96% rename from torchao/_models/llama/eval.py rename to benchmarks/_models/llama/eval.py index 4a67124a08..4c077c92a0 100644 --- a/torchao/_models/llama/eval.py +++ b/benchmarks/_models/llama/eval.py @@ -8,14 +8,13 @@ from typing import List, Optional import torch -from generate import ( - _load_model, - device_sync, -) from tokenizer import get_tokenizer import torchao -from torchao._models.llama.model import prepare_inputs_for_model +from benchmarks._models.llama.model import prepare_inputs_for_model +from benchmarks._models.utils import ( + _load_model, +) from torchao.quantization import ( PerRow, PerTensor, @@ -28,7 +27,11 @@ quantize_, uintx_weight_only, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + device_sync, + unwrap_tensor_subclass, +) def run_evaluation( @@ -120,7 +123,7 @@ def run_evaluation( quantize_(model, int4_weight_only(layout=MarlinSparseLayout())) if "int4wo" in quantization and "gptq" in quantization: # avoid circular imports - from torchao._models._eval import MultiTensorInputRecorder + from benchmarks._models._eval import MultiTensorInputRecorder from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer groupsize = int(quantization.split("-")[-2]) @@ -172,7 +175,7 @@ def run_evaluation( if "autoround" in quantization: from transformers import AutoTokenizer - from torchao._models.llama.model import TransformerBlock + from benchmarks._models.llama.model import TransformerBlock from torchao.prototype.autoround.autoround_llm import ( quantize_model_with_autoround_, ) @@ -242,7 +245,7 @@ def run_evaluation( with torch.no_grad(): print("Running evaluation ...") # avoid circular imports - from torchao._models._eval import TransformerEvalWrapper + from benchmarks._models._eval import TransformerEvalWrapper TransformerEvalWrapper( model=model.to(device), diff --git a/torchao/_models/llama/evals.sh b/benchmarks/_models/llama/evals.sh similarity index 100% rename from torchao/_models/llama/evals.sh rename to benchmarks/_models/llama/evals.sh diff --git a/torchao/_models/llama/generate.py b/benchmarks/_models/llama/generate.py similarity index 91% rename from torchao/_models/llama/generate.py rename to benchmarks/_models/llama/generate.py index 0958a5207c..9f527c31ba 100644 --- a/torchao/_models/llama/generate.py +++ b/benchmarks/_models/llama/generate.py @@ -7,20 +7,30 @@ import time from datetime import datetime from pathlib import Path -from typing import Optional, Tuple +from typing import Optional import torch import 
torch._dynamo.config import torch._inductor.config import torchao -from torchao._models.utils import ( +from benchmarks._models.utils import ( + _load_model, + decode_n_tokens, + decode_one_token, + encode_tokens, get_arch_name, + prefill, write_json_result_local, write_json_result_ossci, ) from torchao.quantization.quant_primitives import MappingType -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, get_model_size_in_bytes +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_5, + default_device, + device_sync, + get_model_size_in_bytes, +) torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False torch.backends.cuda.enable_cudnn_sdp(True) @@ -49,97 +59,12 @@ def device_timer(device): print(f"device={device} is not yet suppported") -def device_sync(device): - if "cuda" in device: - torch.cuda.synchronize(device) - elif "xpu" in device: - torch.xpu.synchronize(device) - elif ("cpu" in device) or ("mps" in device): - pass - else: - print(f"device={device} is not yet suppported") - - -default_device = ( - "cuda" - if torch.cuda.is_available() - else "xpu" - if torch.xpu.is_available() - else "cpu" -) - # support running without installing as a package wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) -from torchao._models.llama.model import Transformer, prepare_inputs_for_model -from torchao._models.llama.tokenizer import get_tokenizer - - -def multinomial_sample_one_no_sync( - probs_sort, -): # Does multinomial sampling without a cuda synchronization - q = torch.empty_like(probs_sort).exponential_(1) - return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) - - -def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): - logits = logits / max(temperature, 1e-5) - - if top_k is not None: - v, _ = torch.topk(logits, min(top_k, logits.size(-1))) - pivot = v.select(-1, -1).unsqueeze(-1) - logits = torch.where(logits < pivot, -float("Inf"), logits) - probs = torch.nn.functional.softmax(logits, dim=-1) - return probs - - -def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): - probs = logits_to_probs(logits[:, -1], temperature, top_k) - idx_next = multinomial_sample_one_no_sync(probs) - return idx_next, probs - - -def prefill( - model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs -) -> torch.Tensor: - # input_pos: [B, S] - logits = model(x, input_pos) - return sample(logits, **sampling_kwargs)[0] - - -def decode_one_token( - model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs -) -> Tuple[torch.Tensor, torch.Tensor]: - # input_pos: [B, 1] - assert input_pos.shape[-1] == 1 - logits = model(x, input_pos) - return sample(logits, **sampling_kwargs) - - -def decode_n_tokens( - model: Transformer, - cur_token: torch.Tensor, - input_pos: torch.Tensor, - num_new_tokens: int, - callback=lambda _: _, - **sampling_kwargs, -): - new_tokens, new_probs = [], [] - for i in range(num_new_tokens): - with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): - next_token, next_prob = decode_one_token( - model, cur_token, input_pos, **sampling_kwargs - ) - next_token, next_prob = next_token.clone(), next_prob.clone() - input_pos += 1 - # in some instances not having this causes weird issues with the stored tokens when you run the next decode_one_token step - new_tokens.append(next_token.clone()) - callback(new_tokens[-1]) - new_probs.append(next_prob) - cur_token = next_token - - return new_tokens, new_probs +from benchmarks._models.llama.model 
import Transformer, prepare_inputs_for_model +from benchmarks._models.llama.tokenizer import get_tokenizer def model_forward(model, x, input_pos): @@ -230,25 +155,6 @@ def generate( return seq -def encode_tokens(tokenizer, string, bos=True, device=default_device): - tokens = tokenizer.encode(string) - if bos: - tokens = [tokenizer.bos_id()] + tokens - return torch.tensor(tokens, dtype=torch.int, device=device) - - -def _load_model(checkpoint_path, device, precision): - checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) - if "model" in checkpoint and "stories" in str(checkpoint_path): - checkpoint = checkpoint["model"] - with torch.device("meta"): - model = Transformer.from_name(checkpoint_path.parent.name) - model.load_state_dict(checkpoint, assign=True) - model = model.to(device=device, dtype=precision) - - return model.eval() - - B_INST, E_INST = "[INST]", "[/INST]" @@ -476,7 +382,7 @@ def ffn_or_attn_only(mod, fqn): filter_fn=lambda x, *args: isinstance(x, torch.nn.Embedding), ) elif quantization.startswith("awq"): - from torchao._models._eval import TransformerEvalWrapper + from benchmarks._models._eval import TransformerEvalWrapper from torchao.utils import TORCH_VERSION_AT_LEAST_2_3 if not TORCH_VERSION_AT_LEAST_2_3: @@ -575,8 +481,8 @@ def ffn_or_attn_only(mod, fqn): model, float8_dynamic_activation_float8_weight(granularity=granularity) ) elif "autoquant_v2" in quantization: - from torchao._models._eval import InputRecorder - from torchao._models.llama.model import prepare_inputs_for_model + from benchmarks._models._eval import InputRecorder + from benchmarks._models.llama.model import prepare_inputs_for_model from torchao.prototype.quantization.autoquant_v2 import autoquant_v2 calibration_seq_length = 256 @@ -665,8 +571,8 @@ def ffn_or_attn_only(mod, fqn): # do autoquantization model.finalize_autoquant() elif "autoquant" in quantization: - from torchao._models._eval import InputRecorder - from torchao._models.llama.model import prepare_inputs_for_model + from benchmarks._models._eval import InputRecorder + from benchmarks._models.llama.model import prepare_inputs_for_model calibration_seq_length = 256 inputs = ( diff --git a/torchao/_models/llama/model.py b/benchmarks/_models/llama/model.py similarity index 100% rename from torchao/_models/llama/model.py rename to benchmarks/_models/llama/model.py diff --git a/torchao/_models/llama/perf_profile.py b/benchmarks/_models/llama/perf_profile.py similarity index 99% rename from torchao/_models/llama/perf_profile.py rename to benchmarks/_models/llama/perf_profile.py index f613982221..d1e9cab83c 100644 --- a/torchao/_models/llama/perf_profile.py +++ b/benchmarks/_models/llama/perf_profile.py @@ -116,8 +116,8 @@ import torch from torch.nn.attention import SDPBackend -from torchao._models.llama.model import Transformer -from torchao._models.llama.tokenizer import get_tokenizer +from benchmarks._models.llama.model import Transformer +from benchmarks._models.llama.tokenizer import get_tokenizer from torchao.prototype.profiler import ( CUDADeviceSpec, TransformerPerformanceCounter, diff --git a/torchao/_models/llama/tokenizer.py b/benchmarks/_models/llama/tokenizer.py similarity index 100% rename from torchao/_models/llama/tokenizer.py rename to benchmarks/_models/llama/tokenizer.py diff --git a/torchao/_models/sam/.gitignore b/benchmarks/_models/sam/.gitignore similarity index 100% rename from torchao/_models/sam/.gitignore rename to benchmarks/_models/sam/.gitignore diff --git a/torchao/_models/sam/README.md 
b/benchmarks/_models/sam/README.md similarity index 100% rename from torchao/_models/sam/README.md rename to benchmarks/_models/sam/README.md diff --git a/benchmarks/_models/sam/__init__.py b/benchmarks/_models/sam/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchao/_models/sam/benchmark.sh b/benchmarks/_models/sam/benchmark.sh similarity index 100% rename from torchao/_models/sam/benchmark.sh rename to benchmarks/_models/sam/benchmark.sh diff --git a/torchao/_models/sam/data.py b/benchmarks/_models/sam/data.py similarity index 100% rename from torchao/_models/sam/data.py rename to benchmarks/_models/sam/data.py diff --git a/torchao/_models/sam/eval_combo.py b/benchmarks/_models/sam/eval_combo.py similarity index 99% rename from torchao/_models/sam/eval_combo.py rename to benchmarks/_models/sam/eval_combo.py index 781c10c935..7f17df4f4f 100644 --- a/torchao/_models/sam/eval_combo.py +++ b/benchmarks/_models/sam/eval_combo.py @@ -9,7 +9,7 @@ from metrics import calculate_miou, create_result_entry import torchao -from torchao._models.utils import ( +from benchmarks._models.utils import ( get_arch_name, write_json_result_local, write_json_result_ossci, diff --git a/torchao/_models/sam/flash_4_configs.p b/benchmarks/_models/sam/flash_4_configs.p similarity index 100% rename from torchao/_models/sam/flash_4_configs.p rename to benchmarks/_models/sam/flash_4_configs.p diff --git a/torchao/_models/sam/metrics.py b/benchmarks/_models/sam/metrics.py similarity index 100% rename from torchao/_models/sam/metrics.py rename to benchmarks/_models/sam/metrics.py diff --git a/torchao/_models/sam/results.csv b/benchmarks/_models/sam/results.csv similarity index 100% rename from torchao/_models/sam/results.csv rename to benchmarks/_models/sam/results.csv diff --git a/torchao/_models/sam/setup.sh b/benchmarks/_models/sam/setup.sh similarity index 100% rename from torchao/_models/sam/setup.sh rename to benchmarks/_models/sam/setup.sh diff --git a/torchao/_models/sam2/__init__.py b/benchmarks/_models/sam2/__init__.py similarity index 81% rename from torchao/_models/sam2/__init__.py rename to benchmarks/_models/sam2/__init__.py index 0dc11c2fde..f49e12ba4e 100644 --- a/torchao/_models/sam2/__init__.py +++ b/benchmarks/_models/sam2/__init__.py @@ -8,4 +8,4 @@ from hydra.core.global_hydra import GlobalHydra if not GlobalHydra.instance().is_initialized(): - initialize_config_module("torchao._models.sam2", version_base="1.2") + initialize_config_module("benchmarks._models.sam2", version_base="1.2") diff --git a/torchao/_models/sam2/automatic_mask_generator.py b/benchmarks/_models/sam2/automatic_mask_generator.py similarity index 99% rename from torchao/_models/sam2/automatic_mask_generator.py rename to benchmarks/_models/sam2/automatic_mask_generator.py index 6f4f1d3e7b..4e82f3ef04 100644 --- a/torchao/_models/sam2/automatic_mask_generator.py +++ b/benchmarks/_models/sam2/automatic_mask_generator.py @@ -11,9 +11,9 @@ import torch from torchvision.ops.boxes import batched_nms, box_area # type: ignore -from torchao._models.sam2.modeling.sam2_base import SAM2Base -from torchao._models.sam2.sam2_image_predictor import SAM2ImagePredictor -from torchao._models.sam2.utils.amg import ( +from benchmarks._models.sam2.modeling.sam2_base import SAM2Base +from benchmarks._models.sam2.sam2_image_predictor import SAM2ImagePredictor +from benchmarks._models.sam2.utils.amg import ( MaskData, _mask_to_rle_pytorch_2_0, _mask_to_rle_pytorch_2_1, @@ -33,7 +33,7 @@ uncrop_masks, uncrop_points, ) -from 
torchao._models.sam2.utils.misc import ( +from benchmarks._models.sam2.utils.misc import ( crop_image, get_image_size, ) diff --git a/torchao/_models/sam2/build_sam.py b/benchmarks/_models/sam2/build_sam.py similarity index 97% rename from torchao/_models/sam2/build_sam.py rename to benchmarks/_models/sam2/build_sam.py index ad0d1fe41c..eea26ccee4 100644 --- a/torchao/_models/sam2/build_sam.py +++ b/benchmarks/_models/sam2/build_sam.py @@ -12,7 +12,7 @@ from hydra.utils import instantiate from omegaconf import OmegaConf -from torchao._models import sam2 +from benchmarks._models import sam2 # Check if the user is running Python from the parent directory of the sam2 repo # (i.e. the directory where this repo is cloned into) -- this is not supported since @@ -106,7 +106,7 @@ def build_sam2_video_predictor( **kwargs, ): hydra_overrides = [ - "++model._target_=torchao._models.sam2.sam2_video_predictor.SAM2VideoPredictor", + "++model._target_=benchmarks._models.sam2.sam2_video_predictor.SAM2VideoPredictor", ] if apply_postprocessing: hydra_overrides_extra = hydra_overrides_extra.copy() diff --git a/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml b/benchmarks/_models/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml similarity index 71% rename from torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml rename to benchmarks/_models/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml index 42cd897c67..1742a20e95 100644 --- a/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +++ b/benchmarks/_models/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml @@ -2,18 +2,18 @@ # Model model: - _target_: torchao._models.sam2.modeling.sam2_base.SAM2Base + _target_: benchmarks._models.sam2.modeling.sam2_base.SAM2Base image_encoder: - _target_: torchao._models.sam2.modeling.backbones.image_encoder.ImageEncoder + _target_: benchmarks._models.sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: - _target_: torchao._models.sam2.modeling.backbones.hieradet.Hiera + _target_: benchmarks._models.sam2.modeling.backbones.hieradet.Hiera embed_dim: 112 num_heads: 2 neck: - _target_: torchao._models.sam2.modeling.backbones.image_encoder.FpnNeck + _target_: benchmarks._models.sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: - _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: benchmarks._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null @@ -24,17 +24,17 @@ model: fpn_interp_model: nearest memory_attention: - _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttention + _target_: benchmarks._models.sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: - _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttentionLayer + _target_: benchmarks._models.sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: - _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention + _target_: benchmarks._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] embedding_dim: 256 @@ -45,7 +45,7 @@ model: pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: - _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention + _target_: benchmarks._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] rope_k_repeat: True @@ -57,23 +57,23 @@ model: num_layers: 4 
memory_encoder: - _target_: torchao._models.sam2.modeling.memory_encoder.MemoryEncoder + _target_: benchmarks._models.sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: - _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: benchmarks._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: - _target_: torchao._models.sam2.modeling.memory_encoder.MaskDownSampler + _target_: benchmarks._models.sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: - _target_: torchao._models.sam2.modeling.memory_encoder.Fuser + _target_: benchmarks._models.sam2.modeling.memory_encoder.Fuser layer: - _target_: torchao._models.sam2.modeling.memory_encoder.CXBlock + _target_: benchmarks._models.sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 diff --git a/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml b/benchmarks/_models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml similarity index 72% rename from torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml rename to benchmarks/_models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml index ba9dafd489..17bf334745 100644 --- a/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml +++ b/benchmarks/_models/sam2/configs/sam2.1/sam2.1_hiera_l.yaml @@ -2,12 +2,12 @@ # Model model: - _target_: torchao._models.sam2.modeling.sam2_base.SAM2Base + _target_: benchmarks._models.sam2.modeling.sam2_base.SAM2Base image_encoder: - _target_: torchao._models.sam2.modeling.backbones.image_encoder.ImageEncoder + _target_: benchmarks._models.sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: - _target_: torchao._models.sam2.modeling.backbones.hieradet.Hiera + _target_: benchmarks._models.sam2.modeling.backbones.hieradet.Hiera embed_dim: 144 num_heads: 2 stages: [2, 6, 36, 4] @@ -15,9 +15,9 @@ model: window_pos_embed_bkg_spatial_size: [7, 7] window_spec: [8, 4, 16, 8] neck: - _target_: torchao._models.sam2.modeling.backbones.image_encoder.FpnNeck + _target_: benchmarks._models.sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: - _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: benchmarks._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null @@ -28,17 +28,17 @@ model: fpn_interp_model: nearest memory_attention: - _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttention + _target_: benchmarks._models.sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: - _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttentionLayer + _target_: benchmarks._models.sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: - _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention + _target_: benchmarks._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] embedding_dim: 256 @@ -49,7 +49,7 @@ model: pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: - _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention + _target_: benchmarks._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] rope_k_repeat: True @@ -61,23 +61,23 @@ model: num_layers: 4 memory_encoder: - _target_: 
torchao._models.sam2.modeling.memory_encoder.MemoryEncoder + _target_: benchmarks._models.sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: - _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: benchmarks._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: - _target_: torchao._models.sam2.modeling.memory_encoder.MaskDownSampler + _target_: benchmarks._models.sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: - _target_: torchao._models.sam2.modeling.memory_encoder.Fuser + _target_: benchmarks._models.sam2.modeling.memory_encoder.Fuser layer: - _target_: torchao._models.sam2.modeling.memory_encoder.CXBlock + _target_: benchmarks._models.sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 diff --git a/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_s.yaml b/benchmarks/_models/sam2/configs/sam2.1/sam2.1_hiera_s.yaml similarity index 72% rename from torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_s.yaml rename to benchmarks/_models/sam2/configs/sam2.1/sam2.1_hiera_s.yaml index 898898b158..7b5f000254 100644 --- a/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_s.yaml +++ b/benchmarks/_models/sam2/configs/sam2.1/sam2.1_hiera_s.yaml @@ -2,21 +2,21 @@ # Model model: - _target_: torchao._models.sam2.modeling.sam2_base.SAM2Base + _target_: benchmarks._models.sam2.modeling.sam2_base.SAM2Base image_encoder: - _target_: torchao._models.sam2.modeling.backbones.image_encoder.ImageEncoder + _target_: benchmarks._models.sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: - _target_: torchao._models.sam2.modeling.backbones.hieradet.Hiera + _target_: benchmarks._models.sam2.modeling.backbones.hieradet.Hiera embed_dim: 96 num_heads: 1 stages: [1, 2, 11, 2] global_att_blocks: [7, 10, 13] window_pos_embed_bkg_spatial_size: [7, 7] neck: - _target_: torchao._models.sam2.modeling.backbones.image_encoder.FpnNeck + _target_: benchmarks._models.sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: - _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: benchmarks._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null @@ -27,17 +27,17 @@ model: fpn_interp_model: nearest memory_attention: - _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttention + _target_: benchmarks._models.sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: - _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttentionLayer + _target_: benchmarks._models.sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: - _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention + _target_: benchmarks._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] embedding_dim: 256 @@ -48,7 +48,7 @@ model: pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: - _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention + _target_: benchmarks._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] rope_k_repeat: True @@ -60,23 +60,23 @@ model: num_layers: 4 memory_encoder: - _target_: torchao._models.sam2.modeling.memory_encoder.MemoryEncoder + _target_: 
benchmarks._models.sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: - _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: benchmarks._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: - _target_: torchao._models.sam2.modeling.memory_encoder.MaskDownSampler + _target_: benchmarks._models.sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: - _target_: torchao._models.sam2.modeling.memory_encoder.Fuser + _target_: benchmarks._models.sam2.modeling.memory_encoder.Fuser layer: - _target_: torchao._models.sam2.modeling.memory_encoder.CXBlock + _target_: benchmarks._models.sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 diff --git a/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_t.yaml b/benchmarks/_models/sam2/configs/sam2.1/sam2.1_hiera_t.yaml similarity index 72% rename from torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_t.yaml rename to benchmarks/_models/sam2/configs/sam2.1/sam2.1_hiera_t.yaml index c6318f843b..84c6e92e9c 100644 --- a/torchao/_models/sam2/configs/sam2.1/sam2.1_hiera_t.yaml +++ b/benchmarks/_models/sam2/configs/sam2.1/sam2.1_hiera_t.yaml @@ -2,21 +2,21 @@ # Model model: - _target_: torchao._models.sam2.modeling.sam2_base.SAM2Base + _target_: benchmarks._models.sam2.modeling.sam2_base.SAM2Base image_encoder: - _target_: torchao._models.sam2.modeling.backbones.image_encoder.ImageEncoder + _target_: benchmarks._models.sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: - _target_: torchao._models.sam2.modeling.backbones.hieradet.Hiera + _target_: benchmarks._models.sam2.modeling.backbones.hieradet.Hiera embed_dim: 96 num_heads: 1 stages: [1, 2, 7, 2] global_att_blocks: [5, 7, 9] window_pos_embed_bkg_spatial_size: [7, 7] neck: - _target_: torchao._models.sam2.modeling.backbones.image_encoder.FpnNeck + _target_: benchmarks._models.sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: - _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: benchmarks._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null @@ -27,17 +27,17 @@ model: fpn_interp_model: nearest memory_attention: - _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttention + _target_: benchmarks._models.sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: - _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttentionLayer + _target_: benchmarks._models.sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: - _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention + _target_: benchmarks._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] embedding_dim: 256 @@ -48,7 +48,7 @@ model: pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: - _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention + _target_: benchmarks._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] rope_k_repeat: True @@ -60,23 +60,23 @@ model: num_layers: 4 memory_encoder: - _target_: torchao._models.sam2.modeling.memory_encoder.MemoryEncoder + _target_: benchmarks._models.sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 
position_encoding: - _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: benchmarks._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: - _target_: torchao._models.sam2.modeling.memory_encoder.MaskDownSampler + _target_: benchmarks._models.sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: - _target_: torchao._models.sam2.modeling.memory_encoder.Fuser + _target_: benchmarks._models.sam2.modeling.memory_encoder.Fuser layer: - _target_: torchao._models.sam2.modeling.memory_encoder.CXBlock + _target_: benchmarks._models.sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 diff --git a/torchao/_models/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml b/benchmarks/_models/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml similarity index 100% rename from torchao/_models/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml rename to benchmarks/_models/sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml diff --git a/torchao/_models/sam2/configs/sam2/sam2_hiera_b+.yaml b/benchmarks/_models/sam2/configs/sam2/sam2_hiera_b+.yaml similarity index 70% rename from torchao/_models/sam2/configs/sam2/sam2_hiera_b+.yaml rename to benchmarks/_models/sam2/configs/sam2/sam2_hiera_b+.yaml index b3ba469471..0f6c1c56cc 100644 --- a/torchao/_models/sam2/configs/sam2/sam2_hiera_b+.yaml +++ b/benchmarks/_models/sam2/configs/sam2/sam2_hiera_b+.yaml @@ -2,18 +2,18 @@ # Model model: - _target_: torchao._models.sam2.modeling.sam2_base.SAM2Base + _target_: benchmarks._models.sam2.modeling.sam2_base.SAM2Base image_encoder: - _target_: torchao._models.sam2.modeling.backbones.image_encoder.ImageEncoder + _target_: benchmarks._models.sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: - _target_: torchao._models.sam2.modeling.backbones.hieradet.Hiera + _target_: benchmarks._models.sam2.modeling.backbones.hieradet.Hiera embed_dim: 112 num_heads: 2 neck: - _target_: torchao._models.sam2.modeling.backbones.image_encoder.FpnNeck + _target_: benchmarks._models.sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: - _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: benchmarks._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null @@ -24,17 +24,17 @@ model: fpn_interp_model: nearest memory_attention: - _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttention + _target_: benchmarks._models.sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: - _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttentionLayer + _target_: benchmarks._models.sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: - _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention + _target_: benchmarks._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] embedding_dim: 256 @@ -45,7 +45,7 @@ model: pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: - _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention + _target_: benchmarks._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] rope_k_repeat: True @@ -57,23 +57,23 
@@ model: num_layers: 4 memory_encoder: - _target_: torchao._models.sam2.modeling.memory_encoder.MemoryEncoder + _target_: benchmarks._models.sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: - _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: benchmarks._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: - _target_: torchao._models.sam2.modeling.memory_encoder.MaskDownSampler + _target_: benchmarks._models.sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: - _target_: torchao._models.sam2.modeling.memory_encoder.Fuser + _target_: benchmarks._models.sam2.modeling.memory_encoder.Fuser layer: - _target_: torchao._models.sam2.modeling.memory_encoder.CXBlock + _target_: benchmarks._models.sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 diff --git a/torchao/_models/sam2/configs/sam2/sam2_hiera_l.yaml b/benchmarks/_models/sam2/configs/sam2/sam2_hiera_l.yaml similarity index 71% rename from torchao/_models/sam2/configs/sam2/sam2_hiera_l.yaml rename to benchmarks/_models/sam2/configs/sam2/sam2_hiera_l.yaml index 59a8a1e36b..4baf4e38eb 100644 --- a/torchao/_models/sam2/configs/sam2/sam2_hiera_l.yaml +++ b/benchmarks/_models/sam2/configs/sam2/sam2_hiera_l.yaml @@ -2,12 +2,12 @@ # Model model: - _target_: torchao._models.sam2.modeling.sam2_base.SAM2Base + _target_: benchmarks._models.sam2.modeling.sam2_base.SAM2Base image_encoder: - _target_: torchao._models.sam2.modeling.backbones.image_encoder.ImageEncoder + _target_: benchmarks._models.sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: - _target_: torchao._models.sam2.modeling.backbones.hieradet.Hiera + _target_: benchmarks._models.sam2.modeling.backbones.hieradet.Hiera embed_dim: 144 num_heads: 2 stages: [2, 6, 36, 4] @@ -15,9 +15,9 @@ model: window_pos_embed_bkg_spatial_size: [7, 7] window_spec: [8, 4, 16, 8] neck: - _target_: torchao._models.sam2.modeling.backbones.image_encoder.FpnNeck + _target_: benchmarks._models.sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: - _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: benchmarks._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null @@ -28,17 +28,17 @@ model: fpn_interp_model: nearest memory_attention: - _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttention + _target_: benchmarks._models.sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: - _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttentionLayer + _target_: benchmarks._models.sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: - _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention + _target_: benchmarks._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] embedding_dim: 256 @@ -49,7 +49,7 @@ model: pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: - _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention + _target_: benchmarks._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] rope_k_repeat: True @@ -61,23 +61,23 @@ model: num_layers: 4 memory_encoder: - _target_: 
torchao._models.sam2.modeling.memory_encoder.MemoryEncoder + _target_: benchmarks._models.sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: - _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: benchmarks._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: - _target_: torchao._models.sam2.modeling.memory_encoder.MaskDownSampler + _target_: benchmarks._models.sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: - _target_: torchao._models.sam2.modeling.memory_encoder.Fuser + _target_: benchmarks._models.sam2.modeling.memory_encoder.Fuser layer: - _target_: torchao._models.sam2.modeling.memory_encoder.CXBlock + _target_: benchmarks._models.sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 diff --git a/torchao/_models/sam2/configs/sam2/sam2_hiera_s.yaml b/benchmarks/_models/sam2/configs/sam2/sam2_hiera_s.yaml similarity index 71% rename from torchao/_models/sam2/configs/sam2/sam2_hiera_s.yaml rename to benchmarks/_models/sam2/configs/sam2/sam2_hiera_s.yaml index b051d3be63..84b4b52a8e 100644 --- a/torchao/_models/sam2/configs/sam2/sam2_hiera_s.yaml +++ b/benchmarks/_models/sam2/configs/sam2/sam2_hiera_s.yaml @@ -2,21 +2,21 @@ # Model model: - _target_: torchao._models.sam2.modeling.sam2_base.SAM2Base + _target_: benchmarks._models.sam2.modeling.sam2_base.SAM2Base image_encoder: - _target_: torchao._models.sam2.modeling.backbones.image_encoder.ImageEncoder + _target_: benchmarks._models.sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: - _target_: torchao._models.sam2.modeling.backbones.hieradet.Hiera + _target_: benchmarks._models.sam2.modeling.backbones.hieradet.Hiera embed_dim: 96 num_heads: 1 stages: [1, 2, 11, 2] global_att_blocks: [7, 10, 13] window_pos_embed_bkg_spatial_size: [7, 7] neck: - _target_: torchao._models.sam2.modeling.backbones.image_encoder.FpnNeck + _target_: benchmarks._models.sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: - _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: benchmarks._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null @@ -27,17 +27,17 @@ model: fpn_interp_model: nearest memory_attention: - _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttention + _target_: benchmarks._models.sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: - _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttentionLayer + _target_: benchmarks._models.sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: - _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention + _target_: benchmarks._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] embedding_dim: 256 @@ -48,7 +48,7 @@ model: pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: - _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention + _target_: benchmarks._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] rope_k_repeat: True @@ -60,23 +60,23 @@ model: num_layers: 4 memory_encoder: - _target_: torchao._models.sam2.modeling.memory_encoder.MemoryEncoder + _target_: 
benchmarks._models.sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: - _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: benchmarks._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: - _target_: torchao._models.sam2.modeling.memory_encoder.MaskDownSampler + _target_: benchmarks._models.sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: - _target_: torchao._models.sam2.modeling.memory_encoder.Fuser + _target_: benchmarks._models.sam2.modeling.memory_encoder.Fuser layer: - _target_: torchao._models.sam2.modeling.memory_encoder.CXBlock + _target_: benchmarks._models.sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 diff --git a/torchao/_models/sam2/configs/sam2/sam2_hiera_t.yaml b/benchmarks/_models/sam2/configs/sam2/sam2_hiera_t.yaml similarity index 72% rename from torchao/_models/sam2/configs/sam2/sam2_hiera_t.yaml rename to benchmarks/_models/sam2/configs/sam2/sam2_hiera_t.yaml index 6b108e708f..b572a7e4ee 100644 --- a/torchao/_models/sam2/configs/sam2/sam2_hiera_t.yaml +++ b/benchmarks/_models/sam2/configs/sam2/sam2_hiera_t.yaml @@ -2,21 +2,21 @@ # Model model: - _target_: torchao._models.sam2.modeling.sam2_base.SAM2Base + _target_: benchmarks._models.sam2.modeling.sam2_base.SAM2Base image_encoder: - _target_: torchao._models.sam2.modeling.backbones.image_encoder.ImageEncoder + _target_: benchmarks._models.sam2.modeling.backbones.image_encoder.ImageEncoder scalp: 1 trunk: - _target_: torchao._models.sam2.modeling.backbones.hieradet.Hiera + _target_: benchmarks._models.sam2.modeling.backbones.hieradet.Hiera embed_dim: 96 num_heads: 1 stages: [1, 2, 7, 2] global_att_blocks: [5, 7, 9] window_pos_embed_bkg_spatial_size: [7, 7] neck: - _target_: torchao._models.sam2.modeling.backbones.image_encoder.FpnNeck + _target_: benchmarks._models.sam2.modeling.backbones.image_encoder.FpnNeck position_encoding: - _target_: torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: benchmarks._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 256 normalize: true scale: null @@ -27,17 +27,17 @@ model: fpn_interp_model: nearest memory_attention: - _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttention + _target_: benchmarks._models.sam2.modeling.memory_attention.MemoryAttention d_model: 256 pos_enc_at_input: true layer: - _target_: torchao._models.sam2.modeling.memory_attention.MemoryAttentionLayer + _target_: benchmarks._models.sam2.modeling.memory_attention.MemoryAttentionLayer activation: relu dim_feedforward: 2048 dropout: 0.1 pos_enc_at_attn: false self_attention: - _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention + _target_: benchmarks._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] embedding_dim: 256 @@ -48,7 +48,7 @@ model: pos_enc_at_cross_attn_keys: true pos_enc_at_cross_attn_queries: false cross_attention: - _target_: torchao._models.sam2.modeling.sam.transformer.RoPEAttention + _target_: benchmarks._models.sam2.modeling.sam.transformer.RoPEAttention rope_theta: 10000.0 feat_sizes: [32, 32] rope_k_repeat: True @@ -60,23 +60,23 @@ model: num_layers: 4 memory_encoder: - _target_: torchao._models.sam2.modeling.memory_encoder.MemoryEncoder + _target_: benchmarks._models.sam2.modeling.memory_encoder.MemoryEncoder out_dim: 64 position_encoding: - _target_: 
torchao._models.sam2.modeling.position_encoding.PositionEmbeddingSine + _target_: benchmarks._models.sam2.modeling.position_encoding.PositionEmbeddingSine num_pos_feats: 64 normalize: true scale: null temperature: 10000 mask_downsampler: - _target_: torchao._models.sam2.modeling.memory_encoder.MaskDownSampler + _target_: benchmarks._models.sam2.modeling.memory_encoder.MaskDownSampler kernel_size: 3 stride: 2 padding: 1 fuser: - _target_: torchao._models.sam2.modeling.memory_encoder.Fuser + _target_: benchmarks._models.sam2.modeling.memory_encoder.Fuser layer: - _target_: torchao._models.sam2.modeling.memory_encoder.CXBlock + _target_: benchmarks._models.sam2.modeling.memory_encoder.CXBlock dim: 256 kernel_size: 7 padding: 3 diff --git a/torchao/_models/sam2/csrc/connected_components.cu b/benchmarks/_models/sam2/csrc/connected_components.cu similarity index 100% rename from torchao/_models/sam2/csrc/connected_components.cu rename to benchmarks/_models/sam2/csrc/connected_components.cu diff --git a/torchao/_models/sam2/map_tensor.py b/benchmarks/_models/sam2/map_tensor.py similarity index 100% rename from torchao/_models/sam2/map_tensor.py rename to benchmarks/_models/sam2/map_tensor.py diff --git a/torchao/_models/sam2/modeling/__init__.py b/benchmarks/_models/sam2/modeling/__init__.py similarity index 100% rename from torchao/_models/sam2/modeling/__init__.py rename to benchmarks/_models/sam2/modeling/__init__.py diff --git a/torchao/_models/sam2/modeling/backbones/__init__.py b/benchmarks/_models/sam2/modeling/backbones/__init__.py similarity index 100% rename from torchao/_models/sam2/modeling/backbones/__init__.py rename to benchmarks/_models/sam2/modeling/backbones/__init__.py diff --git a/torchao/_models/sam2/modeling/backbones/hieradet.py b/benchmarks/_models/sam2/modeling/backbones/hieradet.py similarity index 98% rename from torchao/_models/sam2/modeling/backbones/hieradet.py rename to benchmarks/_models/sam2/modeling/backbones/hieradet.py index 91e98f795e..b56c983c8f 100644 --- a/torchao/_models/sam2/modeling/backbones/hieradet.py +++ b/benchmarks/_models/sam2/modeling/backbones/hieradet.py @@ -13,12 +13,12 @@ import torch.nn.functional as F from iopath.common.file_io import g_pathmgr -from torchao._models.sam2.modeling.backbones.utils import ( +from benchmarks._models.sam2.modeling.backbones.utils import ( PatchEmbed, window_partition, window_unpartition, ) -from torchao._models.sam2.modeling.sam2_utils import MLP, DropPath +from benchmarks._models.sam2.modeling.sam2_utils import MLP, DropPath def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor: diff --git a/torchao/_models/sam2/modeling/backbones/image_encoder.py b/benchmarks/_models/sam2/modeling/backbones/image_encoder.py similarity index 98% rename from torchao/_models/sam2/modeling/backbones/image_encoder.py rename to benchmarks/_models/sam2/modeling/backbones/image_encoder.py index 0f0a256867..efa1d963e4 100644 --- a/torchao/_models/sam2/modeling/backbones/image_encoder.py +++ b/benchmarks/_models/sam2/modeling/backbones/image_encoder.py @@ -29,7 +29,7 @@ def __init__( def forward(self, sample: torch.Tensor): # Forward through backbone with torch.autograd.profiler.record_function("self.neck(self.trunk(sample))"): - from torchao._models.sam2.map_tensor import MapTensor, to_map_tensor + from benchmarks._models.sam2.map_tensor import MapTensor, to_map_tensor if isinstance(sample, MapTensor): features, pos = self.neck(self.trunk(sample.elems.flatten(0, 1))) diff --git 
a/torchao/_models/sam2/modeling/backbones/utils.py b/benchmarks/_models/sam2/modeling/backbones/utils.py similarity index 100% rename from torchao/_models/sam2/modeling/backbones/utils.py rename to benchmarks/_models/sam2/modeling/backbones/utils.py diff --git a/torchao/_models/sam2/modeling/memory_attention.py b/benchmarks/_models/sam2/modeling/memory_attention.py similarity index 97% rename from torchao/_models/sam2/modeling/memory_attention.py rename to benchmarks/_models/sam2/modeling/memory_attention.py index 5ac6288af0..c32707cf31 100644 --- a/torchao/_models/sam2/modeling/memory_attention.py +++ b/benchmarks/_models/sam2/modeling/memory_attention.py @@ -9,8 +9,8 @@ import torch from torch import Tensor, nn -from torchao._models.sam2.modeling.sam.transformer import RoPEAttention -from torchao._models.sam2.modeling.sam2_utils import get_activation_fn, get_clones +from benchmarks._models.sam2.modeling.sam.transformer import RoPEAttention +from benchmarks._models.sam2.modeling.sam2_utils import get_activation_fn, get_clones class MemoryAttentionLayer(nn.Module): diff --git a/torchao/_models/sam2/modeling/memory_encoder.py b/benchmarks/_models/sam2/modeling/memory_encoder.py similarity index 98% rename from torchao/_models/sam2/modeling/memory_encoder.py rename to benchmarks/_models/sam2/modeling/memory_encoder.py index 3796cefd00..84116aa225 100644 --- a/torchao/_models/sam2/modeling/memory_encoder.py +++ b/benchmarks/_models/sam2/modeling/memory_encoder.py @@ -11,7 +11,11 @@ import torch.nn as nn import torch.nn.functional as F -from torchao._models.sam2.modeling.sam2_utils import DropPath, LayerNorm2d, get_clones +from benchmarks._models.sam2.modeling.sam2_utils import ( + DropPath, + LayerNorm2d, + get_clones, +) class MaskDownSampler(nn.Module): diff --git a/torchao/_models/sam2/modeling/position_encoding.py b/benchmarks/_models/sam2/modeling/position_encoding.py similarity index 100% rename from torchao/_models/sam2/modeling/position_encoding.py rename to benchmarks/_models/sam2/modeling/position_encoding.py diff --git a/torchao/_models/sam2/modeling/sam/__init__.py b/benchmarks/_models/sam2/modeling/sam/__init__.py similarity index 100% rename from torchao/_models/sam2/modeling/sam/__init__.py rename to benchmarks/_models/sam2/modeling/sam/__init__.py diff --git a/torchao/_models/sam2/modeling/sam/mask_decoder.py b/benchmarks/_models/sam2/modeling/sam/mask_decoder.py similarity index 99% rename from torchao/_models/sam2/modeling/sam/mask_decoder.py rename to benchmarks/_models/sam2/modeling/sam/mask_decoder.py index 7d25697018..1c29113197 100644 --- a/torchao/_models/sam2/modeling/sam/mask_decoder.py +++ b/benchmarks/_models/sam2/modeling/sam/mask_decoder.py @@ -9,7 +9,7 @@ import torch from torch import nn -from torchao._models.sam2.modeling.sam2_utils import MLP, LayerNorm2d +from benchmarks._models.sam2.modeling.sam2_utils import MLP, LayerNorm2d class MaskDecoder(nn.Module): diff --git a/torchao/_models/sam2/modeling/sam/prompt_encoder.py b/benchmarks/_models/sam2/modeling/sam/prompt_encoder.py similarity index 98% rename from torchao/_models/sam2/modeling/sam/prompt_encoder.py rename to benchmarks/_models/sam2/modeling/sam/prompt_encoder.py index 94b7fda8b2..2c3abbfa34 100644 --- a/torchao/_models/sam2/modeling/sam/prompt_encoder.py +++ b/benchmarks/_models/sam2/modeling/sam/prompt_encoder.py @@ -9,8 +9,8 @@ import torch from torch import nn -from torchao._models.sam2.modeling.position_encoding import PositionEmbeddingRandom -from torchao._models.sam2.modeling.sam2_utils 
import LayerNorm2d +from benchmarks._models.sam2.modeling.position_encoding import PositionEmbeddingRandom +from benchmarks._models.sam2.modeling.sam2_utils import LayerNorm2d class PromptEncoder(nn.Module): diff --git a/torchao/_models/sam2/modeling/sam/transformer.py b/benchmarks/_models/sam2/modeling/sam/transformer.py similarity index 98% rename from torchao/_models/sam2/modeling/sam/transformer.py rename to benchmarks/_models/sam2/modeling/sam/transformer.py index bf0b58d6fd..3c6d3b83cd 100644 --- a/torchao/_models/sam2/modeling/sam/transformer.py +++ b/benchmarks/_models/sam2/modeling/sam/transformer.py @@ -14,12 +14,12 @@ import torch.nn.functional as F from torch import Tensor, nn -from torchao._models.sam2.modeling.position_encoding import ( +from benchmarks._models.sam2.modeling.position_encoding import ( apply_rotary_enc, compute_axial_cis, ) -from torchao._models.sam2.modeling.sam2_utils import MLP -from torchao._models.sam2.utils.misc import get_sdpa_settings +from benchmarks._models.sam2.modeling.sam2_utils import MLP +from benchmarks._models.sam2.utils.misc import get_sdpa_settings warnings.simplefilter(action="ignore", category=FutureWarning) # Check whether Flash Attention is available (and use it by default) diff --git a/torchao/_models/sam2/modeling/sam2_base.py b/benchmarks/_models/sam2/modeling/sam2_base.py similarity index 99% rename from torchao/_models/sam2/modeling/sam2_base.py rename to benchmarks/_models/sam2/modeling/sam2_base.py index 4c2a24a0ef..c5d1f54829 100644 --- a/torchao/_models/sam2/modeling/sam2_base.py +++ b/benchmarks/_models/sam2/modeling/sam2_base.py @@ -9,10 +9,10 @@ import torch.nn.functional as F from torch.nn.init import trunc_normal_ -from torchao._models.sam2.modeling.sam.mask_decoder import MaskDecoder -from torchao._models.sam2.modeling.sam.prompt_encoder import PromptEncoder -from torchao._models.sam2.modeling.sam.transformer import TwoWayTransformer -from torchao._models.sam2.modeling.sam2_utils import ( +from benchmarks._models.sam2.modeling.sam.mask_decoder import MaskDecoder +from benchmarks._models.sam2.modeling.sam.prompt_encoder import PromptEncoder +from benchmarks._models.sam2.modeling.sam.transformer import TwoWayTransformer +from benchmarks._models.sam2.modeling.sam2_utils import ( MLP, get_1d_sine_pe, select_closest_cond_frames, diff --git a/torchao/_models/sam2/modeling/sam2_utils.py b/benchmarks/_models/sam2/modeling/sam2_utils.py similarity index 99% rename from torchao/_models/sam2/modeling/sam2_utils.py rename to benchmarks/_models/sam2/modeling/sam2_utils.py index 579bfc671a..1c00f534e3 100644 --- a/torchao/_models/sam2/modeling/sam2_utils.py +++ b/benchmarks/_models/sam2/modeling/sam2_utils.py @@ -13,7 +13,7 @@ import torch.nn as nn import torch.nn.functional as F -from torchao._models.sam2.utils.misc import mask_to_box +from benchmarks._models.sam2.utils.misc import mask_to_box def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num): diff --git a/torchao/_models/sam2/sam2_hiera_b+.yaml b/benchmarks/_models/sam2/sam2_hiera_b+.yaml similarity index 100% rename from torchao/_models/sam2/sam2_hiera_b+.yaml rename to benchmarks/_models/sam2/sam2_hiera_b+.yaml diff --git a/torchao/_models/sam2/sam2_hiera_l.yaml b/benchmarks/_models/sam2/sam2_hiera_l.yaml similarity index 100% rename from torchao/_models/sam2/sam2_hiera_l.yaml rename to benchmarks/_models/sam2/sam2_hiera_l.yaml diff --git a/torchao/_models/sam2/sam2_hiera_s.yaml b/benchmarks/_models/sam2/sam2_hiera_s.yaml similarity index 100% rename 
from torchao/_models/sam2/sam2_hiera_s.yaml rename to benchmarks/_models/sam2/sam2_hiera_s.yaml diff --git a/torchao/_models/sam2/sam2_hiera_t.yaml b/benchmarks/_models/sam2/sam2_hiera_t.yaml similarity index 100% rename from torchao/_models/sam2/sam2_hiera_t.yaml rename to benchmarks/_models/sam2/sam2_hiera_t.yaml diff --git a/torchao/_models/sam2/sam2_image_predictor.py b/benchmarks/_models/sam2/sam2_image_predictor.py similarity index 99% rename from torchao/_models/sam2/sam2_image_predictor.py rename to benchmarks/_models/sam2/sam2_image_predictor.py index a4aa1c668c..a2c53bdf0a 100644 --- a/torchao/_models/sam2/sam2_image_predictor.py +++ b/benchmarks/_models/sam2/sam2_image_predictor.py @@ -11,9 +11,9 @@ import torch from PIL.Image import Image -from torchao._models.sam2.modeling.sam2_base import SAM2Base -from torchao._models.sam2.utils.misc import get_image_size -from torchao._models.sam2.utils.transforms import SAM2Transforms +from benchmarks._models.sam2.modeling.sam2_base import SAM2Base +from benchmarks._models.sam2.utils.misc import get_image_size +from benchmarks._models.sam2.utils.transforms import SAM2Transforms class SAM2ImagePredictor(torch.nn.Module): diff --git a/torchao/_models/sam2/sam2_video_predictor.py b/benchmarks/_models/sam2/sam2_video_predictor.py similarity index 99% rename from torchao/_models/sam2/sam2_video_predictor.py rename to benchmarks/_models/sam2/sam2_video_predictor.py index 53b0a11d7c..6715178958 100644 --- a/torchao/_models/sam2/sam2_video_predictor.py +++ b/benchmarks/_models/sam2/sam2_video_predictor.py @@ -10,8 +10,8 @@ import torch from tqdm import tqdm -from torchao._models.sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base -from torchao._models.sam2.utils.misc import ( +from benchmarks._models.sam2.modeling.sam2_base import NO_OBJ_SCORE, SAM2Base +from benchmarks._models.sam2.utils.misc import ( concat_points, fill_holes_in_mask_scores, load_video_frames, @@ -52,7 +52,7 @@ def batch_inference_states(inference_states: list): batched_inference_state = copy.copy(inference_states[0]) - from torchao._models.sam2.map_tensor import to_map_tensor + from benchmarks._models.sam2.map_tensor import to_map_tensor # NOTE: Making a build assumption only images differ all_images = torch.stack([state["images"] for state in inference_states]) diff --git a/torchao/_models/sam2/utils/__init__.py b/benchmarks/_models/sam2/utils/__init__.py similarity index 100% rename from torchao/_models/sam2/utils/__init__.py rename to benchmarks/_models/sam2/utils/__init__.py diff --git a/torchao/_models/sam2/utils/amg.py b/benchmarks/_models/sam2/utils/amg.py similarity index 100% rename from torchao/_models/sam2/utils/amg.py rename to benchmarks/_models/sam2/utils/amg.py diff --git a/torchao/_models/sam2/utils/misc.py b/benchmarks/_models/sam2/utils/misc.py similarity index 100% rename from torchao/_models/sam2/utils/misc.py rename to benchmarks/_models/sam2/utils/misc.py diff --git a/torchao/_models/sam2/utils/transforms.py b/benchmarks/_models/sam2/utils/transforms.py similarity index 97% rename from torchao/_models/sam2/utils/transforms.py rename to benchmarks/_models/sam2/utils/transforms.py index c616233050..2d5e46193b 100644 --- a/torchao/_models/sam2/utils/transforms.py +++ b/benchmarks/_models/sam2/utils/transforms.py @@ -78,7 +78,7 @@ def postprocess_masks( """ Perform PostProcessing on output masks. 
""" - from torchao._models.sam2.utils.misc import get_connected_components + from benchmarks._models.sam2.utils.misc import get_connected_components masks = masks.float() input_masks = masks @@ -125,7 +125,7 @@ def postprocess_masks_1_channel( """ Perform PostProcessing on output masks. """ - from torchao._models.sam2.utils.misc import get_connected_components + from benchmarks._models.sam2.utils.misc import get_connected_components assert masks.dim() == 4 assert masks.size(1) == 1 diff --git a/torchao/_models/utils.py b/benchmarks/_models/utils.py similarity index 54% rename from torchao/_models/utils.py rename to benchmarks/_models/utils.py index 346feb57ae..dc2648a209 100644 --- a/torchao/_models/utils.py +++ b/benchmarks/_models/utils.py @@ -4,9 +4,13 @@ import os import platform import time +from typing import Optional, Tuple import torch +from benchmarks._models.llama.model import Transformer +from torchao.utils import default_device + def get_arch_name() -> str: if torch.cuda.is_available(): @@ -104,3 +108,88 @@ def write_json_result_local(output_json_path, headers, row): with open(f"{os.path.splitext(output_json_path)[0]}.json", "a") as f: print(json.dumps(record), file=f) + + +def encode_tokens(tokenizer, string, bos=True, device=default_device): + tokens = tokenizer.encode(string) + if bos: + tokens = [tokenizer.bos_id()] + tokens + return torch.tensor(tokens, dtype=torch.int, device=device) + + +def _load_model(checkpoint_path, device, precision): + checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True) + if "model" in checkpoint and "stories" in str(checkpoint_path): + checkpoint = checkpoint["model"] + with torch.device("meta"): + model = Transformer.from_name(checkpoint_path.parent.name) + model.load_state_dict(checkpoint, assign=True) + model = model.to(device=device, dtype=precision) + + return model.eval() + + +def multinomial_sample_one_no_sync( + probs_sort, +): # Does multinomial sampling without a cuda synchronization + q = torch.empty_like(probs_sort).exponential_(1) + return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int) + + +def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None): + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, min(top_k, logits.size(-1))) + pivot = v.select(-1, -1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None): + probs = logits_to_probs(logits[:, -1], temperature, top_k) + idx_next = multinomial_sample_one_no_sync(probs) + return idx_next, probs + + +def prefill( + model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs +) -> torch.Tensor: + # input_pos: [B, S] + logits = model(x, input_pos) + return sample(logits, **sampling_kwargs)[0] + + +def decode_one_token( + model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs +) -> Tuple[torch.Tensor, torch.Tensor]: + # input_pos: [B, 1] + assert input_pos.shape[-1] == 1 + logits = model(x, input_pos) + return sample(logits, **sampling_kwargs) + + +def decode_n_tokens( + model: Transformer, + cur_token: torch.Tensor, + input_pos: torch.Tensor, + num_new_tokens: int, + callback=lambda _: _, + **sampling_kwargs, +): + new_tokens, new_probs = [], [] + for i in range(num_new_tokens): + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + 
next_token, next_prob = decode_one_token( + model, cur_token, input_pos, **sampling_kwargs + ) + next_token, next_prob = next_token.clone(), next_prob.clone() + input_pos += 1 + # in some instances not having this causes weird issues with the stored tokens when you run the next decode_one_token step + new_tokens.append(next_token.clone()) + callback(new_tokens[-1]) + new_probs.append(next_prob) + cur_token = next_token + + return new_tokens, new_probs diff --git a/benchmarks/quantized_training/pretrain_llama2.py b/benchmarks/quantized_training/pretrain_llama2.py index 25b37921b6..2eb66f5e6b 100644 --- a/benchmarks/quantized_training/pretrain_llama2.py +++ b/benchmarks/quantized_training/pretrain_llama2.py @@ -22,13 +22,13 @@ from torch.utils.checkpoint import checkpoint from tqdm import tqdm -from torchao import quantize_ -from torchao._models.llama.model import ( +from benchmarks._models.llama.model import ( ModelArgs, RMSNorm, Transformer, transformer_configs, ) +from torchao import quantize_ from torchao.prototype import low_bit_optim from torchao.prototype.quantized_training import ( bitnet_training, diff --git a/docs/source/contributor_guide.rst b/docs/source/contributor_guide.rst index ab6d433e27..c204fdc67d 100644 --- a/docs/source/contributor_guide.rst +++ b/docs/source/contributor_guide.rst @@ -125,11 +125,11 @@ After you have the quantization flow implemented, you can run benchmark and eval Note: llama model (llama2/llama3) is our representative model for memory bound models and sam is our representative model for compute bound models. -* `llama `__ - * `benchmark `__ - * `eval `__ -* `sam `__ - * `benchmark and eval `__ +* `llama `__ + * `benchmark `__ + * `eval `__ +* `sam `__ + * `benchmark and eval `__ Please checkout the ``--help`` option for each of the script to understand the supported options, e.g. you can use ``--profile=profile_path`` to get the chrome trace of the run to understand detailed `chrome trace `__. 
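# Illustrative only: a minimal sketch of how the generation helpers relocated into
# benchmarks/_models/utils.py by this change (encode_tokens, _load_model, prefill,
# decode_n_tokens) are typically chained. The tokenizer, checkpoint path, and
# sampling settings below are placeholder assumptions, not values from this patch,
# and the model's KV caches are assumed to be set up by the caller as the llama
# benchmark scripts do.
import torch

from benchmarks._models.utils import (
    _load_model,
    decode_n_tokens,
    encode_tokens,
    prefill,
)


def generate_sketch(tokenizer, checkpoint_path, prompt: str, max_new_tokens: int = 32):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = _load_model(checkpoint_path, device, torch.bfloat16)
    tokens = encode_tokens(tokenizer, prompt, bos=True, device=device)
    prompt_len = tokens.size(0)
    # Prefill consumes the whole prompt and samples the first new token.
    input_pos = torch.arange(0, prompt_len, device=device)
    next_token = prefill(
        model, tokens.view(1, -1), input_pos, temperature=0.8, top_k=200
    )
    # Decode the remaining tokens one position at a time.
    new_tokens, _ = decode_n_tokens(
        model,
        next_token.view(1, -1),
        torch.tensor([prompt_len], device=device, dtype=torch.int),
        max_new_tokens - 1,
        temperature=0.8,
        top_k=200,
    )
    return torch.cat([tokens, next_token.view(-1)] + [t.view(-1) for t in new_tokens])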
diff --git a/examples/sam2_amg_server/annotate_with_rle.py b/examples/sam2_amg_server/annotate_with_rle.py index 55e5512011..3c3bbc77b0 100644 --- a/examples/sam2_amg_server/annotate_with_rle.py +++ b/examples/sam2_amg_server/annotate_with_rle.py @@ -14,7 +14,7 @@ ) from tqdm import tqdm -from torchao._models.sam2.utils.amg import area_from_rle, rle_to_mask +from benchmarks._models.sam2.utils.amg import area_from_rle, rle_to_mask def timestamped_print(*args, **kwargs): diff --git a/examples/sam2_amg_server/cli.py b/examples/sam2_amg_server/cli.py index 2f6758b7d3..b5feac395e 100644 --- a/examples/sam2_amg_server/cli.py +++ b/examples/sam2_amg_server/cli.py @@ -12,9 +12,9 @@ show_anns, ) -from torchao._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator -from torchao._models.sam2.build_sam import build_sam2 -from torchao._models.sam2.utils.amg import rle_to_mask +from benchmarks._models.sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator +from benchmarks._models.sam2.build_sam import build_sam2 +from benchmarks._models.sam2.utils.amg import rle_to_mask def main_docstring(): diff --git a/examples/sam2_amg_server/cli_on_modal.py b/examples/sam2_amg_server/cli_on_modal.py index 5fe56eeb1a..d44de90bf7 100644 --- a/examples/sam2_amg_server/cli_on_modal.py +++ b/examples/sam2_amg_server/cli_on_modal.py @@ -84,10 +84,10 @@ def build(self): from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator from sam2.build_sam import build_sam2 else: - from torchao._models.sam2.automatic_mask_generator import ( + from benchmarks._models.sam2.automatic_mask_generator import ( SAM2AutomaticMaskGenerator, ) - from torchao._models.sam2.build_sam import build_sam2 + from benchmarks._models.sam2.build_sam import build_sam2 os.chdir(f"{TARGET}ao_src_0/examples/sam2_amg_server") import sys @@ -139,11 +139,11 @@ def build(self): from sam2.utils.amg import mask_to_rle_pytorch as mask_to_rle_pytorch_2 from sam2.utils.amg import rle_to_mask else: - from torchao._models.sam2.utils.amg import ( + from benchmarks._models.sam2.utils.amg import ( mask_to_rle_pytorch_2, rle_to_mask, ) - from torchao._models.sam2.utils.amg import area_from_rle + from benchmarks._models.sam2.utils.amg import area_from_rle self.np = np self.tio = tio diff --git a/examples/sam2_amg_server/compare_rle_lists.py b/examples/sam2_amg_server/compare_rle_lists.py index 7a1c78b846..88be3df491 100644 --- a/examples/sam2_amg_server/compare_rle_lists.py +++ b/examples/sam2_amg_server/compare_rle_lists.py @@ -7,7 +7,7 @@ import torch -# from torchao._models.sam2.utils.amg import rle_to_mask +# from benchmarks._models.sam2.utils.amg import rle_to_mask def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: """Compute a binary mask from an uncompressed RLE.""" h, w = rle["size"] diff --git a/examples/sam2_amg_server/compile_export_utils.py b/examples/sam2_amg_server/compile_export_utils.py index d1c6fc06fa..a1b6b5f891 100644 --- a/examples/sam2_amg_server/compile_export_utils.py +++ b/examples/sam2_amg_server/compile_export_utils.py @@ -4,7 +4,7 @@ import torch -from torchao._models.sam2.sam2_image_predictor import SAM2ImagePredictor +from benchmarks._models.sam2.sam2_image_predictor import SAM2ImagePredictor # Tools used to avoid compilation cold start and dynamo cache lookups # We take the compiled model and export it using the largest @@ -513,18 +513,18 @@ def set_fast( dynamic=True, ) - import torchao + import benchmarks if allow_recompiles: # A bunch of extra compiles at module level # Note that this can cause 
recompilations! # We might want to guard on that - torchao._models.sam2.utils.amg._mask_to_rle_pytorch_2_0_0 = torch.compile( + benchmarks._models.sam2.utils.amg._mask_to_rle_pytorch_2_0_0 = torch.compile( fullgraph=True, dynamic=True - )(torchao._models.sam2.utils.amg._mask_to_rle_pytorch_2_0_0) - torchao._models.sam2.utils.amg._mask_to_rle_pytorch_2_0_1 = torch.compile( + )(benchmarks._models.sam2.utils.amg._mask_to_rle_pytorch_2_0_0) + benchmarks._models.sam2.utils.amg._mask_to_rle_pytorch_2_0_1 = torch.compile( fullgraph=True, dynamic=True - )(torchao._models.sam2.utils.amg._mask_to_rle_pytorch_2_0_1) + )(benchmarks._models.sam2.utils.amg._mask_to_rle_pytorch_2_0_1) mask_generator.calculate_stability_score = torch.compile( fullgraph=True, dynamic=True )(mask_generator.calculate_stability_score) diff --git a/examples/sam2_amg_server/generate_data.py b/examples/sam2_amg_server/generate_data.py index 50eeccb912..dc82348d0b 100644 --- a/examples/sam2_amg_server/generate_data.py +++ b/examples/sam2_amg_server/generate_data.py @@ -192,7 +192,7 @@ def gen_masks_ao_batch( center_points_label_torch_batch = [ torch.from_numpy(t).unsqueeze(1) for t in center_points_label_batch ] - from torchao._models.sam2.map_tensor import to_map_tensor + from benchmarks._models.sam2.map_tensor import to_map_tensor center_points_torch_batch = list(map(to_map_tensor, center_points_torch_batch)) center_points_label_torch_batch = list( @@ -255,7 +255,7 @@ def gen_masks_ao( center_points_torch = torch.from_numpy(center_points).unsqueeze(1) center_points_label_torch = torch.from_numpy(center_points_label).unsqueeze(1) - from torchao._models.sam2.map_tensor import to_map_tensor + from benchmarks._models.sam2.map_tensor import to_map_tensor center_points_torch = to_map_tensor(center_points_torch) center_points_label_torch = to_map_tensor(center_points_label_torch) @@ -532,11 +532,11 @@ def main( from sam2.build_sam import build_sam2 from sam2.utils.amg import mask_to_rle_pytorch else: - from torchao._models.sam2.automatic_mask_generator import ( + from benchmarks._models.sam2.automatic_mask_generator import ( SAM2AutomaticMaskGenerator, ) - from torchao._models.sam2.build_sam import build_sam2 - from torchao._models.sam2.utils.amg import ( + from benchmarks._models.sam2.build_sam import build_sam2 + from benchmarks._models.sam2.utils.amg import ( mask_to_rle_pytorch_2 as mask_to_rle_pytorch, ) torch.manual_seed(seed) diff --git a/examples/sam2_amg_server/result_batch_size_16.csv b/examples/sam2_amg_server/result_batch_size_16.csv index 4e8c338df4..0d59b0a6cf 100644 --- a/examples/sam2_amg_server/result_batch_size_16.csv +++ b/examples/sam2_amg_server/result_batch_size_16.csv @@ -32,21 +32,21 @@ num-images,total_time,first,p99,baseline,max,export-model,second,furious,environ File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/utils/_contextlib.py"", line 116, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/automatic_mask_generator.py"", line 243, in generate_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/automatic_mask_generator.py"", line 243, in generate_batch data = self._generate_masks_batch(images) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/automatic_mask_generator.py"", line 292, in _generate_masks_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/automatic_mask_generator.py"", line 292, in _generate_masks_batch all_data = 
self._process_crop_batch(images, all_crop_boxes, all_layer_idxs, all_orig_size) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/automatic_mask_generator.py"", line 384, in _process_crop_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/automatic_mask_generator.py"", line 384, in _process_crop_batch self.predictor.set_image_batch(all_cropped_im) File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/utils/_contextlib.py"", line 116, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch backbone_out = self.model.forward_image(img_batch) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/modeling/sam2_base.py"", line 474, in forward_image + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/modeling/sam2_base.py"", line 474, in forward_image backbone_out[""backbone_fpn""][0] = self.sam_mask_decoder.conv_s0( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/nn/modules/module.py"", line 1740, in _wrapped_call_impl @@ -90,21 +90,21 @@ RuntimeError: cuDNN error: CUDNN_STATUS_INTERNAL_ERROR File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/utils/_contextlib.py"", line 116, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/automatic_mask_generator.py"", line 243, in generate_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/automatic_mask_generator.py"", line 243, in generate_batch data = self._generate_masks_batch(images) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/automatic_mask_generator.py"", line 292, in _generate_masks_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/automatic_mask_generator.py"", line 292, in _generate_masks_batch all_data = self._process_crop_batch(images, all_crop_boxes, all_layer_idxs, all_orig_size) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/automatic_mask_generator.py"", line 384, in _process_crop_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/automatic_mask_generator.py"", line 384, in _process_crop_batch self.predictor.set_image_batch(all_cropped_im) File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/utils/_contextlib.py"", line 116, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch backbone_out = self.model.forward_image(img_batch) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/modeling/sam2_base.py"", line 474, in forward_image + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/modeling/sam2_base.py"", line 474, in forward_image backbone_out[""backbone_fpn""][0] = self.sam_mask_decoder.conv_s0( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File 
""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/nn/modules/module.py"", line 1740, in _wrapped_call_impl @@ -186,21 +186,21 @@ Traceback (most recent call last): File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/utils/_contextlib.py"", line 116, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/automatic_mask_generator.py"", line 243, in generate_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/automatic_mask_generator.py"", line 243, in generate_batch data = self._generate_masks_batch(images) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/automatic_mask_generator.py"", line 292, in _generate_masks_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/automatic_mask_generator.py"", line 292, in _generate_masks_batch all_data = self._process_crop_batch(images, all_crop_boxes, all_layer_idxs, all_orig_size) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/automatic_mask_generator.py"", line 384, in _process_crop_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/automatic_mask_generator.py"", line 384, in _process_crop_batch self.predictor.set_image_batch(all_cropped_im) File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/utils/_contextlib.py"", line 116, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch backbone_out = self.model.forward_image(img_batch) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/modeling/sam2_base.py"", line 469, in forward_image + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/modeling/sam2_base.py"", line 469, in forward_image backbone_out = self.image_encoder(img_batch) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/nn/modules/module.py"", line 1740, in _wrapped_call_impl @@ -244,21 +244,21 @@ RuntimeError: run_func_( container_handle_, input_handles.data(), input_handles. 
File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/utils/_contextlib.py"", line 116, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/automatic_mask_generator.py"", line 243, in generate_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/automatic_mask_generator.py"", line 243, in generate_batch data = self._generate_masks_batch(images) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/automatic_mask_generator.py"", line 292, in _generate_masks_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/automatic_mask_generator.py"", line 292, in _generate_masks_batch all_data = self._process_crop_batch(images, all_crop_boxes, all_layer_idxs, all_orig_size) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/automatic_mask_generator.py"", line 384, in _process_crop_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/automatic_mask_generator.py"", line 384, in _process_crop_batch self.predictor.set_image_batch(all_cropped_im) File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/utils/_contextlib.py"", line 116, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch backbone_out = self.model.forward_image(img_batch) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/modeling/sam2_base.py"", line 474, in forward_image + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/modeling/sam2_base.py"", line 474, in forward_image backbone_out[""backbone_fpn""][0] = self.sam_mask_decoder.conv_s0( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/nn/modules/module.py"", line 1740, in _wrapped_call_impl @@ -302,21 +302,21 @@ RuntimeError: cuDNN error: CUDNN_STATUS_INTERNAL_ERROR File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/utils/_contextlib.py"", line 116, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/automatic_mask_generator.py"", line 243, in generate_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/automatic_mask_generator.py"", line 243, in generate_batch data = self._generate_masks_batch(images) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/automatic_mask_generator.py"", line 292, in _generate_masks_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/automatic_mask_generator.py"", line 292, in _generate_masks_batch all_data = self._process_crop_batch(images, all_crop_boxes, all_layer_idxs, all_orig_size) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/automatic_mask_generator.py"", line 384, in _process_crop_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/automatic_mask_generator.py"", line 384, in _process_crop_batch self.predictor.set_image_batch(all_cropped_im) File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/utils/_contextlib.py"", line 116, in 
decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch backbone_out = self.model.forward_image(img_batch) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/modeling/sam2_base.py"", line 474, in forward_image + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/modeling/sam2_base.py"", line 474, in forward_image backbone_out[""backbone_fpn""][0] = self.sam_mask_decoder.conv_s0( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/nn/modules/module.py"", line 1740, in _wrapped_call_impl @@ -343,7 +343,7 @@ W0104 14:58:02.413000 1111794 site-packages/torch/_inductor/select_algorithm.py: W0104 14:58:03.167000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 294912, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. W0104 14:58:04.568000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 262144, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. AUTOTUNE mm(4194304x256, 256x128) - mm 8.7354 ms 100.0% + mm 8.7354 ms 100.0% triton_mm_146 13.3706 ms 65.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 triton_mm_139 17.0872 ms 51.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 triton_mm_141 17.6846 ms 49.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 @@ -361,7 +361,7 @@ W0104 14:58:07.799000 1111794 site-packages/torch/_inductor/select_algorithm.py: W0104 14:58:08.210000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 294912, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. W0104 14:58:08.894000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 262144, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. AUTOTUNE mm(8192x256, 256x2048) - mm 0.2846 ms 100.0% + mm 0.2846 ms 100.0% triton_mm_184 0.4445 ms 64.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 triton_mm_177 0.5668 ms 50.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 triton_mm_179 0.5790 ms 49.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 @@ -379,7 +379,7 @@ W0104 14:58:11.387000 1111794 site-packages/torch/_inductor/select_algorithm.py: W0104 14:58:11.755000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 294912, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. 
W0104 14:58:12.364000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 262144, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. AUTOTUNE mm(1024x256, 256x256) - mm 0.0186 ms 100.0% + mm 0.0186 ms 100.0% triton_mm_626 0.0359 ms 51.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8 triton_mm_627 0.0361 ms 51.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8 triton_mm_625 0.0365 ms 50.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4 @@ -397,7 +397,7 @@ W0104 14:58:14.841000 1111794 site-packages/torch/_inductor/select_algorithm.py: W0104 14:58:15.202000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 294912, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. W0104 14:58:15.806000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 262144, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. AUTOTUNE mm(1024x256, 256x256) - mm 0.0180 ms 100.0% + mm 0.0180 ms 100.0% triton_mm_646 0.0357 ms 50.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8 triton_mm_645 0.0360 ms 49.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8 triton_mm_644 0.0370 ms 48.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4 @@ -415,7 +415,7 @@ W0104 14:58:16.861000 1111794 site-packages/torch/_inductor/select_algorithm.py: W0104 14:58:17.223000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 294912, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. W0104 14:58:17.833000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 262144, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. AUTOTUNE mm(1024x256, 256x256) - mm 0.0185 ms 100.0% + mm 0.0185 ms 100.0% triton_mm_682 0.0360 ms 51.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8 triton_mm_681 0.0364 ms 50.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8 triton_mm_680 0.0365 ms 50.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4 @@ -433,7 +433,7 @@ W0104 14:58:18.895000 1111794 site-packages/torch/_inductor/select_algorithm.py: W0104 14:58:19.255000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 294912, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. 
W0104 14:58:19.866000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 262144, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. AUTOTUNE mm(1024x256, 256x256) - mm 0.0186 ms 100.0% + mm 0.0186 ms 100.0% triton_mm_736 0.0360 ms 51.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8 triton_mm_737 0.0360 ms 51.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8 triton_mm_735 0.0365 ms 50.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4 @@ -451,7 +451,7 @@ W0104 14:58:20.929000 1111794 site-packages/torch/_inductor/select_algorithm.py: W0104 14:58:21.292000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 294912, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. W0104 14:58:21.909000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 262144, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. AUTOTUNE mm(1024x256, 256x256) - mm 0.0180 ms 100.0% + mm 0.0180 ms 100.0% triton_mm_792 0.0361 ms 50.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8 triton_mm_791 0.0363 ms 49.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8 triton_mm_790 0.0370 ms 48.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4 @@ -469,7 +469,7 @@ W0104 14:58:22.960000 1111794 site-packages/torch/_inductor/select_algorithm.py: W0104 14:58:23.317000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 294912, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. W0104 14:58:23.931000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 262144, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. 
AUTOTUNE mm(1024x256, 256x256) - mm 0.0185 ms 100.0% + mm 0.0185 ms 100.0% triton_mm_847 0.0361 ms 51.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8 triton_mm_846 0.0363 ms 51.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8 triton_mm_845 0.0368 ms 50.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4 @@ -483,7 +483,7 @@ SingleProcess AUTOTUNE benchmarking takes 2.0045 seconds and 0.0040 seconds prec AUTOTUNE mm(1024x256, 256x4) triton_mm_883 0.0162 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=16, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=2 triton_mm_884 0.0162 ms 99.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=16, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2 - mm 0.0166 ms 97.5% + mm 0.0166 ms 97.5% triton_mm_885 0.0232 ms 69.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4 triton_mm_889 0.0233 ms 69.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=16, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 triton_mm_890 0.0235 ms 68.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 @@ -501,7 +501,7 @@ AUTOTUNE mm(2048x2, 2x128) triton_mm_5 0.0073 ms 91.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=4 triton_mm_7 0.0077 ms 86.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4 triton_mm_6 0.0078 ms 85.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=8 - mm 0.0079 ms 84.2% + mm 0.0079 ms 84.2% triton_mm_8 0.0083 ms 80.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4 SingleProcess AUTOTUNE benchmarking takes 2.3704 seconds and 0.0024 seconds precompiling for 17 choices E0104 14:58:30.506000 1111794 site-packages/torch/_inductor/select_algorithm.py:1400] [0/0] Exception out of resource: shared memory, Required: 294912, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_10/amg_fast_export_inductor_cache_dir/cz/cczuf4mbz67rz32kb4erom4hh3extdrznp22adm5ibnzg5hixbva.py, ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4) @@ -511,8 +511,8 @@ W0104 14:58:31.755000 1111794 site-packages/torch/_inductor/select_algorithm.py: W0104 14:58:32.124000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 294912, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. 
W0104 14:58:32.745000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 262144, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. AUTOTUNE addmm(8192x256, 8192x256, 256x256) - bias_addmm 0.0492 ms 100.0% - addmm 0.0681 ms 72.3% + bias_addmm 0.0492 ms 100.0% + addmm 0.0681 ms 72.3% triton_mm_27 0.0801 ms 61.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 triton_mm_29 0.0805 ms 61.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 triton_mm_25 0.0822 ms 59.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 @@ -529,8 +529,8 @@ W0104 14:58:33.985000 1111794 site-packages/torch/_inductor/select_algorithm.py: W0104 14:58:34.346000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 294912, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. W0104 14:58:34.965000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 262144, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. AUTOTUNE addmm(8192x128, 8192x256, 256x128) - bias_addmm 0.0313 ms 100.0% - addmm 0.0400 ms 78.1% + bias_addmm 0.0313 ms 100.0% + addmm 0.0400 ms 78.1% triton_mm_101 0.0577 ms 54.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 triton_mm_105 0.0588 ms 53.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 triton_mm_103 0.0625 ms 50.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 @@ -543,7 +543,7 @@ SingleProcess AUTOTUNE benchmarking takes 2.1995 seconds and 0.0039 seconds prec E0104 14:58:34.979000 1111794 site-packages/torch/_inductor/select_algorithm.py:1400] [0/0] Exception out of resource: shared memory, Required: 262144, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. for benchmark choice TritonTemplateCaller(/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_10/amg_fast_export_inductor_cache_dir/f5/cf54lpxyskhyrlnsvgwdvrzswqz4avvyso3u2cqlseqwgbpj7pgv.py, ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8) W0104 14:58:37.259000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 262144, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. 
AUTOTUNE mm(8192x128, 128x256) - mm 0.0332 ms 100.0% + mm 0.0332 ms 100.0% triton_mm_162 0.0454 ms 73.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 triton_mm_160 0.0457 ms 72.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 triton_mm_158 0.0467 ms 71.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 @@ -561,7 +561,7 @@ W0104 14:58:38.545000 1111794 site-packages/torch/_inductor/select_algorithm.py: W0104 14:58:38.959000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 294912, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. W0104 14:58:39.649000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 262144, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. AUTOTUNE mm(8192x2048, 2048x256) - mm 0.2634 ms 100.0% + mm 0.2634 ms 100.0% triton_mm_198 0.5623 ms 46.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 triton_mm_200 0.5694 ms 46.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 triton_mm_196 0.5824 ms 45.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 @@ -579,7 +579,7 @@ W0104 14:58:40.825000 1111794 site-packages/torch/_inductor/select_algorithm.py: W0104 14:58:41.198000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 294912, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. W0104 14:58:41.816000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 262144, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. AUTOTUNE mm(8192x256, 256x256) - mm 0.0553 ms 100.0% + mm 0.0553 ms 100.0% triton_mm_350 0.0801 ms 69.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 triton_mm_352 0.0803 ms 68.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 triton_mm_348 0.0818 ms 67.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 @@ -593,7 +593,7 @@ SingleProcess AUTOTUNE benchmarking takes 2.1333 seconds and 0.0039 seconds prec E0104 14:58:41.828000 1111794 site-packages/torch/_inductor/select_algorithm.py:1400] [0/0] Exception out of resource: shared memory, Required: 262144, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. 
for benchmark choice TritonTemplateCaller(/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_10/amg_fast_export_inductor_cache_dir/sn/csnohx66tfenmoj7n2bmwgbic34up2jtkkpubt6ri3ulzzs65i4x.py, ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8) W0104 14:58:48.250000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 262144, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. AUTOTUNE mm(4194304x128, 128x256) - mm 9.4713 ms 100.0% + mm 9.4713 ms 100.0% triton_mm_279 13.9709 ms 67.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 triton_mm_272 17.6967 ms 53.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 triton_mm_274 18.6221 ms 50.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 @@ -611,8 +611,8 @@ W0104 14:58:52.143000 1111794 site-packages/torch/_inductor/select_algorithm.py: W0104 14:58:52.895000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 294912, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. W0104 14:58:54.313000 1111794 site-packages/torch/_inductor/select_algorithm.py:1619] [0/0] out of resource: shared memory, Required: 262144, Hardware limit: 232448. Reducing block sizes or `num_stages` may help. AUTOTUNE addmm(4194304x128, 4194304x256, 256x128) - bias_addmm 8.5930 ms 100.0% - addmm 11.2420 ms 76.4% + bias_addmm 8.5930 ms 100.0% + addmm 11.2420 ms 76.4% triton_mm_393 13.5410 ms 63.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 triton_mm_386 17.1705 ms 50.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 triton_mm_388 17.8044 ms 48.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4 @@ -623,10 +623,10 @@ AUTOTUNE addmm(4194304x128, 4194304x256, 256x128) triton_mm_391 28.6740 ms 30.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8 SingleProcess AUTOTUNE benchmarking takes 6.0558 seconds and 0.0034 seconds precompiling for 21 choices AUTOTUNE addmm(1024x32, 1024x256, 256x32) - bias_addmm 0.0174 ms 100.0% + bias_addmm 0.0174 ms 100.0% triton_mm_664 0.0227 ms 76.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4 triton_mm_663 0.0227 ms 76.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4 - addmm 0.0228 ms 76.3% + addmm 0.0228 ms 76.3% triton_mm_662 0.0333 ms 52.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=32, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=1, num_warps=2 triton_mm_665 0.0354 ms 49.1% 
ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=8 triton_mm_669 0.0354 ms 49.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=32, B_PROLOGUE_CAST_TYPE=None, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8 @@ -699,21 +699,21 @@ Traceback (most recent call last): File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/utils/_contextlib.py"", line 116, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/automatic_mask_generator.py"", line 243, in generate_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/automatic_mask_generator.py"", line 243, in generate_batch data = self._generate_masks_batch(images) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/automatic_mask_generator.py"", line 292, in _generate_masks_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/automatic_mask_generator.py"", line 292, in _generate_masks_batch all_data = self._process_crop_batch(images, all_crop_boxes, all_layer_idxs, all_orig_size) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/automatic_mask_generator.py"", line 384, in _process_crop_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/automatic_mask_generator.py"", line 384, in _process_crop_batch self.predictor.set_image_batch(all_cropped_im) File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/utils/_contextlib.py"", line 116, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch backbone_out = self.model.forward_image(img_batch) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/modeling/sam2_base.py"", line 469, in forward_image + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/modeling/sam2_base.py"", line 469, in forward_image backbone_out = self.image_encoder(img_batch) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/nn/modules/module.py"", line 1740, in _wrapped_call_impl @@ -773,10 +773,10 @@ RuntimeError: run_func_( container_handle_, input_handles.data(), input_handles. 
File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/utils/_contextlib.py"", line 116, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch backbone_out = self.model.forward_image(img_batch) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/modeling/sam2_base.py"", line 474, in forward_image + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/modeling/sam2_base.py"", line 474, in forward_image backbone_out[""backbone_fpn""][0] = self.sam_mask_decoder.conv_s0( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/nn/modules/module.py"", line 1740, in _wrapped_call_impl @@ -819,10 +819,10 @@ RuntimeError: cuDNN error: CUDNN_STATUS_INTERNAL_ERROR File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/utils/_contextlib.py"", line 116, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch backbone_out = self.model.forward_image(img_batch) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/modeling/sam2_base.py"", line 474, in forward_image + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/modeling/sam2_base.py"", line 474, in forward_image backbone_out[""backbone_fpn""][0] = self.sam_mask_decoder.conv_s0( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/nn/modules/module.py"", line 1740, in _wrapped_call_impl @@ -903,10 +903,10 @@ Traceback (most recent call last): File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/utils/_contextlib.py"", line 116, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch backbone_out = self.model.forward_image(img_batch) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/modeling/sam2_base.py"", line 469, in forward_image + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/modeling/sam2_base.py"", line 469, in forward_image backbone_out = self.image_encoder(img_batch) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/nn/modules/module.py"", line 1740, in _wrapped_call_impl @@ -949,10 +949,10 @@ RuntimeError: run_func_( container_handle_, input_handles.data(), input_handles. 
File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/utils/_contextlib.py"", line 116, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch backbone_out = self.model.forward_image(img_batch) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/modeling/sam2_base.py"", line 474, in forward_image + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/modeling/sam2_base.py"", line 474, in forward_image backbone_out[""backbone_fpn""][0] = self.sam_mask_decoder.conv_s0( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/nn/modules/module.py"", line 1740, in _wrapped_call_impl @@ -995,10 +995,10 @@ RuntimeError: cuDNN error: CUDNN_STATUS_INTERNAL_ERROR File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/utils/_contextlib.py"", line 116, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch backbone_out = self.model.forward_image(img_batch) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/modeling/sam2_base.py"", line 474, in forward_image + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/modeling/sam2_base.py"", line 474, in forward_image backbone_out[""backbone_fpn""][0] = self.sam_mask_decoder.conv_s0( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/nn/modules/module.py"", line 1740, in _wrapped_call_impl @@ -1079,10 +1079,10 @@ Traceback (most recent call last): File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/utils/_contextlib.py"", line 116, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch backbone_out = self.model.forward_image(img_batch) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/modeling/sam2_base.py"", line 469, in forward_image + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/modeling/sam2_base.py"", line 469, in forward_image backbone_out = self.image_encoder(img_batch) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/nn/modules/module.py"", line 1740, in _wrapped_call_impl @@ -1142,10 +1142,10 @@ RuntimeError: run_func_( container_handle_, input_handles.data(), input_handles. 
File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/utils/_contextlib.py"", line 116, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch backbone_out = self.model.forward_image(img_batch) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/modeling/sam2_base.py"", line 474, in forward_image + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/modeling/sam2_base.py"", line 474, in forward_image backbone_out[""backbone_fpn""][0] = self.sam_mask_decoder.conv_s0( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/nn/modules/module.py"", line 1740, in _wrapped_call_impl @@ -1188,10 +1188,10 @@ RuntimeError: cuDNN error: CUDNN_STATUS_INTERNAL_ERROR File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/utils/_contextlib.py"", line 116, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch backbone_out = self.model.forward_image(img_batch) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/modeling/sam2_base.py"", line 474, in forward_image + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/modeling/sam2_base.py"", line 474, in forward_image backbone_out[""backbone_fpn""][0] = self.sam_mask_decoder.conv_s0( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/nn/modules/module.py"", line 1740, in _wrapped_call_impl @@ -1272,10 +1272,10 @@ Traceback (most recent call last): File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/utils/_contextlib.py"", line 116, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch backbone_out = self.model.forward_image(img_batch) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/modeling/sam2_base.py"", line 469, in forward_image + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/modeling/sam2_base.py"", line 469, in forward_image backbone_out = self.image_encoder(img_batch) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/nn/modules/module.py"", line 1740, in _wrapped_call_impl @@ -1318,10 +1318,10 @@ RuntimeError: run_func_( container_handle_, input_handles.data(), input_handles. 
File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/utils/_contextlib.py"", line 116, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch backbone_out = self.model.forward_image(img_batch) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/modeling/sam2_base.py"", line 474, in forward_image + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/modeling/sam2_base.py"", line 474, in forward_image backbone_out[""backbone_fpn""][0] = self.sam_mask_decoder.conv_s0( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/nn/modules/module.py"", line 1740, in _wrapped_call_impl @@ -1364,10 +1364,10 @@ RuntimeError: cuDNN error: CUDNN_STATUS_INTERNAL_ERROR File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/utils/_contextlib.py"", line 116, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch backbone_out = self.model.forward_image(img_batch) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/modeling/sam2_base.py"", line 474, in forward_image + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/modeling/sam2_base.py"", line 474, in forward_image backbone_out[""backbone_fpn""][0] = self.sam_mask_decoder.conv_s0( ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/nn/modules/module.py"", line 1740, in _wrapped_call_impl @@ -1387,7 +1387,7 @@ RuntimeError: cuDNN error: CUDNN_STATUS_INTERNAL_ERROR ,,,,,,,,,{'TORCHINDUCTOR_CACHE_DIR': '/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_10/mps_fast_export_inductor_cache_dir'},mps_ao_ppb_None_fast_export_gpu_preproc,82.75403904914856,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_10/exported_models/mps_ao_fast,,,16,,,,,,/home/cpuhrsch/blogs/tmp/sam2_amg_example_run_10/amg_baseline_annotations,,,,"W0104 18:14:14.202000 2235960 site-packages/torch/_logging/_internal.py:1084] [0/0] Profiler function will be ignored /home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/_inductor/compile_fx.py:222: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance. 
warnings.warn( -V0104 18:14:58.688000 2235960 site-packages/torch/_dynamo/guards.py:2760] [0/1] [__recompiles] Recompiling function _predict_masks in /home/cpuhrsch/dev/ao/torchao/_models/sam2/sam2_image_predictor.py:432 +V0104 18:14:58.688000 2235960 site-packages/torch/_dynamo/guards.py:2760] [0/1] [__recompiles] Recompiling function _predict_masks in /home/cpuhrsch/dev/ao/benchmarks/_models/sam2/sam2_image_predictor.py:432 V0104 18:14:58.688000 2235960 site-packages/torch/_dynamo/guards.py:2760] [0/1] [__recompiles] triggered by the following guard failure(s): V0104 18:14:58.688000 2235960 site-packages/torch/_dynamo/guards.py:2760] [0/1] [__recompiles] - 0/0: Ne(L['self']._modules['model']._modules['sam_mask_decoder']._modules['transformer']._modules['final_attn_token_to_image'].num_heads*((128//L['self']._modules['model']._modules['sam_mask_decoder']._modules['transformer']._modules['final_attn_token_to_image'].num_heads)), 8*L['point_coords'].elems.size()[0]) # (_inductor/pattern_matcher.py:1288 in ) [E104 18:15:24.766972949 shim_common.cpp:376] Exception in aoti_torch: CUDA out of memory. Tried to allocate 576.00 MiB. GPU 0 has a total capacity of 94.99 GiB of which 498.44 MiB is free. Including non-PyTorch memory, this process has 94.49 GiB memory in use. Of the allocated memory 91.63 GiB is allocated by PyTorch, and 1.31 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables) @@ -1454,10 +1454,10 @@ Traceback (most recent call last): File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/utils/_contextlib.py"", line 116, in decorate_context return func(*args, **kwargs) ^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/sam2_image_predictor.py"", line 172, in set_image_batch backbone_out = self.model.forward_image(img_batch) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - File ""/home/cpuhrsch/dev/ao/torchao/_models/sam2/modeling/sam2_base.py"", line 469, in forward_image + File ""/home/cpuhrsch/dev/ao/benchmarks/_models/sam2/modeling/sam2_base.py"", line 469, in forward_image backbone_out = self.image_encoder(img_batch) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ File ""/home/cpuhrsch/.conda/envs/nightly20241126py312/lib/python3.12/site-packages/torch/nn/modules/module.py"", line 1740, in _wrapped_call_impl diff --git a/examples/sam2_amg_server/server.py b/examples/sam2_amg_server/server.py index 7e35858590..ea9953dbed 100644 --- a/examples/sam2_amg_server/server.py +++ b/examples/sam2_amg_server/server.py @@ -26,7 +26,7 @@ from fastapi.responses import StreamingResponse from torch._inductor import config as inductorconfig -from torchao._models.utils import ( +from benchmarks._models.utils import ( get_arch_name, write_json_result_local, write_json_result_ossci, @@ -460,11 +460,11 @@ def main( from sam2.build_sam import build_sam2 from sam2.utils.amg import rle_to_mask else: - from torchao._models.sam2.automatic_mask_generator import ( + from benchmarks._models.sam2.automatic_mask_generator import ( SAM2AutomaticMaskGenerator, ) - from torchao._models.sam2.build_sam import build_sam2 - from torchao._models.sam2.utils.amg import rle_to_mask + from benchmarks._models.sam2.build_sam import build_sam2 + from 
benchmarks._models.sam2.utils.amg import rle_to_mask device = "cuda" sam2_checkpoint, model_cfg = model_type_to_paths(checkpoint_path, model_type) diff --git a/examples/sam2_vos_example/compile_export_utils.py b/examples/sam2_vos_example/compile_export_utils.py index 7d1b3eddf3..00f1b56794 100644 --- a/examples/sam2_vos_example/compile_export_utils.py +++ b/examples/sam2_vos_example/compile_export_utils.py @@ -4,7 +4,7 @@ import torch -from torchao._models.sam2.sam2_video_predictor import SAM2VideoPredictor +from benchmarks._models.sam2.sam2_video_predictor import SAM2VideoPredictor # Tools used to avoid compilation cold start and dynamo cache lookups # We take the compiled model and export it using the largest diff --git a/examples/sam2_vos_example/video_profile.py b/examples/sam2_vos_example/video_profile.py index 8ee9151cc4..44b90bd77b 100644 --- a/examples/sam2_vos_example/video_profile.py +++ b/examples/sam2_vos_example/video_profile.py @@ -280,7 +280,7 @@ def main( if use_baseline: from sam2.build_sam import build_sam2_video_predictor else: - from torchao._models.sam2.build_sam import build_sam2_video_predictor + from benchmarks._models.sam2.build_sam import build_sam2_video_predictor device = "cuda:0" # hydra_overrides_extra = ["++model.compile_image_encoder=true"] @@ -292,7 +292,7 @@ def main( ) predictor._frame_batch_size = frame_batch_size predictor.image_encoder.trunk = predictor.image_encoder.trunk.to(torch.bfloat16) - from torchao._models.sam2.modeling.sam.transformer import RoPEAttention + from benchmarks._models.sam2.modeling.sam.transformer import RoPEAttention rope_attention_modules = [ module for module in predictor.modules() if isinstance(module, RoPEAttention) diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py index e05f23da2a..1b0939c951 100644 --- a/scripts/convert_hf_checkpoint.py +++ b/scripts/convert_hf_checkpoint.py @@ -14,7 +14,7 @@ import torch from safetensors.torch import load_file as load_safetensors_file -from torchao._models.llama.model import ModelArgs +from benchmarks._models.llama.model import ModelArgs @torch.inference_mode() diff --git a/test/prototype/test_spinquant.py b/test/prototype/test_spinquant.py index 42606b014e..a50b9d9cb7 100644 --- a/test/prototype/test_spinquant.py +++ b/test/prototype/test_spinquant.py @@ -1,7 +1,7 @@ import pytest import torch -from torchao._models.llama.model import Transformer +from benchmarks._models.llama.model import Transformer from torchao.prototype.spinquant import apply_spinquant diff --git a/test/quantization/test_gptq_mt.py b/test/quantization/test_gptq_mt.py index 5d4e73ed61..f82315714b 100644 --- a/test/quantization/test_gptq_mt.py +++ b/test/quantization/test_gptq_mt.py @@ -5,8 +5,8 @@ import torch.nn.functional as F from torch.testing._internal.common_utils import run_tests -from torchao._models.llama.model import Transformer, prepare_inputs_for_model -from torchao._models.llama.tokenizer import get_tokenizer +from benchmarks._models.llama.model import Transformer, prepare_inputs_for_model +from benchmarks._models.llama.tokenizer import get_tokenizer from torchao.quantization.GPTQ_MT import Int4WeightOnlyGPTQQuantizer, MultiTensor from torchao.quantization.utils import _lm_eval_available from torchao.utils import is_fbcode diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 4af429940f..1176367a3d 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -21,9 +21,9 @@ from torch.testing._internal 
import common_utils from torch.testing._internal.common_utils import TestCase +from benchmarks._models.llama.model import Transformer, prepare_inputs_for_model +from benchmarks._models.llama.tokenizer import get_tokenizer from torchao import quantize_ -from torchao._models.llama.model import Transformer, prepare_inputs_for_model -from torchao._models.llama.tokenizer import get_tokenizer from torchao.dtypes import AffineQuantizedTensor from torchao.quantization import LinearActivationQuantizedTensor from torchao.quantization.quant_api import ( @@ -278,7 +278,7 @@ def test_8da4w_quantizer(self): # https://github.com/pytorch-labs/gpt-fast/blob/6253c6bb054e658d67566150f87329b87815ae63/scripts/convert_hf_checkpoint.py @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_8da4w_gptq_quantizer(self): - from torchao._models._eval import InputRecorder, TransformerEvalWrapper + from benchmarks._models._eval import InputRecorder, TransformerEvalWrapper from torchao.quantization.GPTQ import Int8DynActInt4WeightGPTQQuantizer # should be similar to TorchCompileDynamicQuantizer @@ -348,7 +348,7 @@ def test_8da4w_gptq_quantizer(self): not TORCH_VERSION_AT_LEAST_2_4, "skipping when torch verion is 2.4 or lower" ) def test_8da4w_quantizer_eval(self): - from torchao._models._eval import TransformerEvalWrapper + from benchmarks._models._eval import TransformerEvalWrapper from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer precision = torch.bfloat16 @@ -384,7 +384,7 @@ def test_8da4w_quantizer_eval(self): @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_gptq_quantizer_int4_weight_only(self): - from torchao._models._eval import ( + from benchmarks._models._eval import ( MultiTensorInputRecorder, TransformerEvalWrapper, ) @@ -454,7 +454,7 @@ def test_gptq_quantizer_int4_weight_only(self): @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_quantizer_int4_weight_only(self): - from torchao._models._eval import TransformerEvalWrapper + from benchmarks._models._eval import TransformerEvalWrapper from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer precision = torch.bfloat16 @@ -492,7 +492,7 @@ def test_quantizer_int4_weight_only(self): @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_eval_wrapper(self): - from torchao._models._eval import TransformerEvalWrapper + from benchmarks._models._eval import TransformerEvalWrapper precision = torch.bfloat16 device = "cuda" @@ -525,7 +525,7 @@ def test_eval_wrapper(self): # EVAL IS CURRENTLY BROKEN FOR LLAMA 3, VERY LOW ACCURACY @unittest.skip("skipping until we get checkpoints for gpt-fast") def test_eval_wrapper_llama3(self): - from torchao._models._eval import TransformerEvalWrapper + from benchmarks._models._eval import TransformerEvalWrapper precision = torch.bfloat16 device = "cuda" diff --git a/test/test_ao_models.py b/test/test_ao_models.py index 49385b0a99..064e2a9a54 100644 --- a/test/test_ao_models.py +++ b/test/test_ao_models.py @@ -1,7 +1,7 @@ import pytest import torch -from torchao._models.llama.model import Transformer +from benchmarks._models.llama.model import Transformer _AVAILABLE_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) diff --git a/torchao/prototype/awq/README.md b/torchao/prototype/awq/README.md index 1040610db5..5f50f2703c 100644 --- a/torchao/prototype/awq/README.md +++ b/torchao/prototype/awq/README.md @@ -2,7 +2,7 @@ Adapted from https://github.com/mit-han-lab/llm-awq ## Benchmarks -Evaluation 
perplexity numbers were calculated using the script in awq/example.py Group size of 64 was used for all quantization methods. For Llama-2-7b-chat-hf, performance benchmarks were calculated using the torchao/_models/llama/generate.py script and run on a 1xA100 80GB SXM4 instance. The awq-uint4 quantization method does not use an efficient fused kernel which is why performance is not great. awq-hqq uses tinygemm int4->bf16 kernel + hqq to provide better performance. +Evaluation perplexity numbers were calculated using the script in awq/example.py Group size of 64 was used for all quantization methods. For Llama-2-7b-chat-hf, performance benchmarks were calculated using the benchmarks/_models/llama/generate.py script and run on a 1xA100 80GB SXM4 instance. The awq-uint4 quantization method does not use an efficient fused kernel which is why performance is not great. awq-hqq uses tinygemm int4->bf16 kernel + hqq to provide better performance. | Model | Quantization | Tokens/sec | Throughput (GB/sec) | Peak Mem (GB) | Model Size (GB) | |--------------------|--------------|------------|---------------------|---------------|-----------------| @@ -23,9 +23,3 @@ The following tests were performed using LM eval and groupsize = 128 | | awq-uint4 | 11.409 | 0.519 | 0.756 | 0.577 | | | int4wo-hqq | 11.905 | 0.528 | 0.757 | 0.563 | | | int4wo-128 | 12.380 | 0.502 | 0.753 | 0.548 | - - - - - - diff --git a/torchao/prototype/quantization/mixed_precision/scripts/BO_acc_throughput.py b/torchao/prototype/quantization/mixed_precision/scripts/BO_acc_throughput.py index 12fc77bd9a..251dff5ba0 100644 --- a/torchao/prototype/quantization/mixed_precision/scripts/BO_acc_throughput.py +++ b/torchao/prototype/quantization/mixed_precision/scripts/BO_acc_throughput.py @@ -18,15 +18,19 @@ ) import torchao -from torchao._models.llama.generate import ( +from benchmarks._models.llama.model import ( + KVCache, + Transformer, + prepare_inputs_for_model, +) +from benchmarks._models.llama.tokenizer import get_tokenizer +from benchmarks._models.utils import ( _load_model, decode_one_token, - device_sync, encode_tokens, prefill, ) -from torchao._models.llama.model import Transformer, prepare_inputs_for_model -from torchao._models.llama.tokenizer import get_tokenizer +from torchao.utils import device_sync default_device = "cuda" if torch.cuda.is_available() else "cpu" @@ -99,7 +103,7 @@ def generate( _replace_with_custom_fn_if_matches_filter( model, AffineQuantizedKVCache.from_float, - lambda x, y: isinstance(x, torchao._models.llama.model.KVCache), + lambda x, y: isinstance(x, KVCache), ) # format model input @@ -396,7 +400,7 @@ def run_sequential_BO( args, ): """ - currently use the loader and benchmark code from torchao/_models/llama/generate, + currently use the loader and benchmark code from benchmarks/_models/llama/generate, and use lm_eval for ppl evaluation """ # load tokenizers diff --git a/torchao/prototype/spinquant/spinquant.py b/torchao/prototype/spinquant/spinquant.py index 60ad1a8b41..bfa83a332a 100644 --- a/torchao/prototype/spinquant/spinquant.py +++ b/torchao/prototype/spinquant/spinquant.py @@ -10,7 +10,7 @@ import torch from torch import nn -from torchao._models.llama.model import RMSNorm, Transformer +from benchmarks._models.llama.model import RMSNorm, Transformer from torchao.prototype.spinquant.hadamard_utils import ( apply_exact_had_to_linear, get_hadK, diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index b278e22b3b..02bb73a903 100644 --- a/torchao/quantization/GPTQ.py +++ 
b/torchao/quantization/GPTQ.py @@ -79,9 +79,9 @@ def __init__( # trace model for one input one_input = [multi.values[0].cpu() for multi in inputs] # pyre-ignore[16] # needed for GPTQ on the torchao llama model - import torchao + import benchmarks - torchao._models.llama.model.use_index_put_for_kv_cache = True + benchmarks._models.llama.model.use_index_put_for_kv_cache = True exported_model = torch._dynamo.export( model.cpu(), aten_graph=True, pre_dispatch=True, tracing_mode="fake" )(*one_input) diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index d2b6e0c016..5610779bfe 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -320,7 +320,7 @@ Note that the workaround is also required for `torch.compile` with `freezing` (` ### KV Cache Quantization We've added kv cache quantization and other features in order to enable long context length (and necessarily memory efficient) inference. -In practice these features alongside int4 weight only quantization allow us to **reduce peak memory by ~55%**, meaning we can Llama3.1-8B inference with a **130k context length with only 18.9 GB of peak memory.** More details can be found [here](../../torchao/_models/llama/README.md#KV-Cache-Quantization-Memory-Efficient-Inference) +In practice these features alongside int4 weight only quantization allow us to **reduce peak memory by ~55%**, meaning we can Llama3.1-8B inference with a **130k context length with only 18.9 GB of peak memory.** More details can be found [here](../../benchmarks/_models/llama/README.md#KV-Cache-Quantization-Memory-Efficient-Inference) ### Sparse-Marlin @@ -346,7 +346,7 @@ Marlin QQQ is an optimized GPU kernel that supports W4A8 mixed precision GEMM. F | | w4a8-g128 | 187.62 | 640.32 | 4.82 | 3.41 | ### Gemlite Triton -Int4 and Int8 quantization using the [Gemlite Triton](https://github.com/mobiusml/gemlite) kernels. You can try it out with the `quantize_` api as above alongside the constructor `gemlite_uintx_weight_only`. An example can be found in `torchao/_models/llama/generate.py`. +Int4 and Int8 quantization using the [Gemlite Triton](https://github.com/mobiusml/gemlite) kernels. You can try it out with the `quantize_` api as above alongside the constructor `gemlite_uintx_weight_only`. An example can be found in `benchmarks/_models/llama/generate.py`. Note: we test on gemlite 0.4.1, but should be able to use any version after that, we'd recommend to use the latest release to get the most recent performance improvements. @@ -362,7 +362,7 @@ We're trying to develop kernels for low bit quantization for intx quantization f | | uintx-4-64-hqq | 8.124 | 47.85 | 213.24 | 11.85 | 4.46 | | | uintx-2-8-hqq | 39.605 | 34.83 | 261.42 | 14.99 | 7.51 | -You try can out these apis with the `quantize_` api as above alongside the config `UIntXWeightOnlyConfig`. An example can be found in in `torchao/_models/llama/generate.py`. +You try can out these apis with the `quantize_` api as above alongside the config `UIntXWeightOnlyConfig`. An example can be found in in `benchmarks/_models/llama/generate.py`. ### int8_dynamic_activation_intx_weight Quantization We have kernels that do 8-bit dynamic quantization of activations and uintx groupwise quantization of weights. These kernels are experimental and can only be run on a device with an ARM CPU (e.g., a Mac computers with Apple silicon). The benchmarks below were run on an M1 Mac Pro, with 8 perf cores, and 2 efficiency cores, and 32GB of RAM. In all cases, torch.compile was used. 
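For reference, a minimal sketch of the `quantize_` call pattern the hunks above refer to (a hedged illustration, assuming a torchao build where `quantize_` and `int4_weight_only` are importable from `torchao.quantization`; the uintx, gemlite, and codebook configs named in this README follow the same pattern):

```python
import torch
import torch.nn as nn

from torchao.quantization import int4_weight_only, quantize_

# Toy model; any module tree containing nn.Linear layers is handled the same way.
model = nn.Sequential(nn.Linear(1024, 1024, bias=False)).cuda().to(torch.bfloat16)

# quantize_ swaps the weights of eligible nn.Linear modules in-place with
# quantized tensor subclasses; int4_weight_only() is just one example config.
quantize_(model, int4_weight_only())

x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
y = model(x)
```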
@@ -373,7 +373,7 @@ We have kernels that do 8-bit dynamic quantization of activations and uintx grou | | int8_dynamic_activation_intx_weight-4-256-false | 16.03 | 65.81 | NA | 4.11 | | | int8_dynamic_activation_intx_weight-3-256-false | 18.94 | 59.97 | NA | 3.17 | -You can try out these apis with the `quantize_` api as above alongside the constructor `int8_dynamic_activation_intx_weight`. An example can be found in `torchao/_models/llama/generate.py`. +You can try out these apis with the `quantize_` api as above alongside the constructor `int8_dynamic_activation_intx_weight`. An example can be found in `benchmarks/_models/llama/generate.py`. ### Codebook Quantization The benchmarks below were run on a single NVIDIA-A6000 GPU. @@ -385,7 +385,7 @@ The benchmarks below were run on a single NVIDIA-A6000 GPU. | Llama-3.1-8B| Base (bfloat16) | 7.713 | 32.16 | 482.70 | 16.35 | 15.01 | | | codebook-4-64 | 10.095 | 1.73 | 8.63 | 23.11 | 4.98 | -You try can out these apis with the `quantize_` api as above alongside the constructor `codebook_weight_only` an example can be found in in `torchao/_models/llama/generate.py`. +You try can out these apis with the `quantize_` api as above alongside the constructor `codebook_weight_only` an example can be found in in `benchmarks/_models/llama/generate.py`. ### Automatic Inductor Configuration @@ -396,7 +396,7 @@ The `quantize_` and `autoquant` apis now automatically use our recommended induc ## (To be moved to prototype) A16W4 WeightOnly Quantization with GPTQ ```python -from torchao._models._eval import InputRecorder, TransformerEvalWrapper +from benchmarks._models._eval import InputRecorder, TransformerEvalWrapper from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer precision = torch.bfloat16 device = "cuda" diff --git a/torchao/sparsity/README.md b/torchao/sparsity/README.md index b689a3adf4..fced804b65 100644 --- a/torchao/sparsity/README.md +++ b/torchao/sparsity/README.md @@ -28,7 +28,7 @@ The following benchmarks we ran for sam ViT-h on an NVIDIA-A100-80GB, with batch | | 2:4 sparsity (attn + mlp) | 24.30 | 13429 | 0.5306 | **1.07x** | **91.31%** | | | int8 dynamic quant (attn)
int8 dynamic quant + 2:4 sparsity (mlp lin1)
2:4 sparsity (mlp lin2) | 26.46 | 14865 | 0.5668 | **1.16x** | **97.54%** | -To reproduce our benchmarks please follow these [instructions](/torchao/_models/sam/README.md). +To reproduce our benchmarks please follow these [instructions](/benchmarks/_models/sam/README.md). ### LLama3 diff --git a/torchao/utils.py b/torchao/utils.py index 2a67f8a9c9..c814fd7b27 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -641,6 +641,26 @@ def is_sm_at_least_100(): ) +default_device = ( + "cuda" + if torch.cuda.is_available() + else "xpu" + if torch.xpu.is_available() + else "cpu" +) + + +def device_sync(device): + if "cuda" in device: + torch.cuda.synchronize(device) + elif "xpu" in device: + torch.xpu.synchronize(device) + elif ("cpu" in device) or ("mps" in device): + pass + else: + print(f"device={device} is not yet suppported") + + TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev") TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev") TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev") From b4b5a5e66258e2798bb6bdb302ff777696bcf887 Mon Sep 17 00:00:00 2001 From: Vasiliy Kuznetsov Date: Tue, 4 Mar 2025 07:18:34 -0800 Subject: [PATCH 07/11] roofline estimator: add float8 rowwise and mxfp8 recipe support (#1789) * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] --- benchmarks/float8/float8_roofline.py | 182 +++++++++++--- benchmarks/float8/utils.py | 20 +- torchao/testing/float8/roofline_utils.py | 301 +++++++++++++++++------ 3 files changed, 397 insertions(+), 106 deletions(-) diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py index d29ee865e6..ac58873bce 100644 --- a/benchmarks/float8/float8_roofline.py +++ b/benchmarks/float8/float8_roofline.py @@ -47,6 +47,7 @@ import pandas as pd import sympy import torch +import torch.nn as nn import torch.utils.benchmark as benchmark import tqdm from torch.profiler import ProfilerActivity, profile @@ -57,8 +58,11 @@ ) from torchao.float8 import ( + Float8LinearConfig, convert_to_float8_training, ) +from torchao.prototype.mx_formats.config import MXLinearConfig +from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear from torchao.testing.float8.roofline_utils import ( get_float8_mem_sympy, get_gemm_time_sympy, @@ -93,17 +97,19 @@ def benchmark_fn_in_sec(f, *args, **kwargs): return measurement.mean -def get_gpu_kernel_time(m, x): +def get_gpu_kernel_time(m, x, grad_output): # warm up for _ in range(2): - m(x).sum().backward() + y = m(x) + y.backward(grad_output) # capture a profiling run activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA] n_iter = 5 with profile(activities=activities) as prof: for _ in range(n_iter): - m(x).sum().backward() + y = m(x) + y.backward(grad_output) torch.cuda.synchronize() # get the gpu kernel time and aggregate it num_leaf_tensors = 1 + len(list(m.parameters())) @@ -114,10 +120,28 @@ def get_gpu_kernel_time(m, x): return total_time_s -def get_gemm_times(M, K, N, fast_accum, cache_filename=None): +def get_gemm_times( + gemm_role: str, + M: int, + K: int, + N: int, + fast_accum: bool, + bf16_memory_formats: str, + float8_recipe_name: Optional[str], + mx_recipe_name: Optional[str], + cache_filename=None, +): + assert gemm_role in ("output", "grad_input", "grad_weight"), "unsupported" + assert 
bf16_memory_formats in ( + "row_major:col_major", + "row_major:row_major", + "col_major:row_major", + ), "unsupported" + # Note: this is definitely not the best way to build a cache, # but it will do for now. if cache_filename is not None: + assert False, "TODO retest this for new arguments" if os.path.isfile(cache_filename): # cache already exists, use it with open(cache_filename, "r") as f: @@ -127,7 +151,7 @@ def get_gemm_times(M, K, N, fast_accum, cache_filename=None): cache = dict() else: cache = dict() - key = f"{M},{K},{N},{fast_accum}" + key = f"{M},{K},{N},{fast_accum},{bf16_memory_formats}" if key in cache: return cache[key] @@ -135,22 +159,40 @@ def get_gemm_times(M, K, N, fast_accum, cache_filename=None): # bf16 time x_bf16 = torch.randn(M, K, dtype=torch.bfloat16, device=device) - w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t() + # w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device).t().contiguous().t() + w_bf16 = torch.randn(K, N, dtype=torch.bfloat16, device=device) + + if bf16_memory_formats == "row_major:col_major": + w_bf16 = w_bf16.t().contiguous().t() + elif bf16_memory_formats == "col_major:row_major": + x_bf16 = x_bf16.t().contiguous().t() + elif bf16_memory_formats == "col_major:row_major": + x_bf16 = x_bf16.t().contiguous().t() + bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16) # f8 time - d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16 - A = torch.zeros(M, K, device=device, dtype=d1) - B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t() - scale_a = torch.tensor([1.0], device=device) - scale_b = torch.tensor([1.0], device=device) - - def do_matmul(A, B): - return torch._scaled_mm( - A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum - ) + if float8_recipe_name == "rowwise_with_gw_hp" and gemm_role == "grad_weight": + f8_time_s = bf16_time_s + else: + d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, torch.bfloat16 + A = torch.zeros(M, K, device=device, dtype=d1) + B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t() + if float8_recipe_name == "tensorwise": + scale_a = torch.tensor([1.0], device=device) + scale_b = torch.tensor([1.0], device=device) + elif float8_recipe_name in ("rowwise", "rowwise_with_gw_hp"): + scale_a = torch.ones(M, 1, device=device) + scale_b = torch.ones(1, N, device=device) + else: + assert False, "TODO add mx gemm here" + + def do_matmul(A, B): + return torch._scaled_mm( + A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum + ) - f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B) + f8_time_s = get_gpu_kernel_gemm_time_s(do_matmul, A, B) # save to cache if needed if cache_filename is not None: @@ -164,33 +206,52 @@ def do_matmul(A, B): def run( outfile: str, do_benchmarks: bool = True, - shape_gen_name: str = "square", + shape_gen_name: str = "pow2", gemm_cache_filename: Optional[str] = None, n_limit: Optional[int] = None, + float8_recipe_name: Optional[str] = None, + mx_recipe_name: Optional[str] = None, + enable_fusion_modeling: bool = False, ): """ Args: * `do_benchmarks`: if True, gemm and e2e fwd+bwd of LNLinearSigmoid are benchmarked - * `shape_gen_name`: `llama`, `square`, or `sweep` + * `shape_gen_name`: `llama`, `pow2`, `pow2_extended`, or `sweep` * `gemm_cache_filename (optional)`: file to cache gemm benchmark results * `n_limit (optional)`: if specified, only runs `n_limit` iterations + * `enable_fusion_modeling`: if False uses Linear, if True uses LNLinearSigmoid and models the 
fusion of float8 overhead """ + assert not ( + (float8_recipe_name is not None) and (mx_recipe_name is not None) + ), "unsupported" + if float8_recipe_name is None and mx_recipe_name is None: + float8_recipe_name = "tensorwise" + + print(f"GPU: {torch.cuda.get_device_name(0)}") print(f"do_benchmarks: {do_benchmarks}") print(f"shape_gen_name: {shape_gen_name}") + print(f"float8_recipe_name: {float8_recipe_name}") + print(f"mx_recipe_name: {mx_recipe_name}") + print(f"enable_fusion_modeling: {enable_fusion_modeling}") M, K, N = sympy.symbols("M K N") - fp8_mem_time_sympy_dyn_nolimit = get_float8_mem_sympy( + fp8_ovhd_time_sympy = get_float8_mem_sympy( M, K, N, + float8_recipe_name, + mx_recipe_name, + enable_fusion_modeling, + ) + bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16, None, None) + fp8_gemm_time_sympy = get_gemm_time_sympy( + M, K, N, torch.float8_e4m3fn, float8_recipe_name, mx_recipe_name ) - - bf16_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.bfloat16) print("bf16_gemm_time_sympy", bf16_gemm_time_sympy) - fp8_gemm_time_sympy = get_gemm_time_sympy(M, K, N, torch.float8_e4m3fn) print("fp8_gemm_time_sympy", fp8_gemm_time_sympy) + print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy) print() headers = [ @@ -217,6 +278,9 @@ def run( # the difference is the fwd+bwd ln and sigmoid terms, for now to keep things simple # we don't break them out and don't have a roofline for them. "b_fp8_e2e_spdp", + # how well benchmarked gemms match roofline predicted gemms + "rb_bf16_gemm_ratio", + "rb_fp8_gemm_ratio", ] results = [] @@ -237,43 +301,96 @@ def run( # if enabled, also measured observed gemm time b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0 + rb_bf16_gemm_ratio = -1 + rb_fp8_gemm_ratio = -1 + if do_benchmarks: + # TODO(future): make the bf16 gemm times exactly match the e2e + # benchmarks, there is a slight deviation, probably related to gemm + # operand memory formats/transpositions below not exactly matching + # what PyTorch core is doing for `torch.mm` + # input @ weight_t = output bf16_g1, f8_g1 = get_gemm_times( - M_val, K_val, N_val, True, gemm_cache_filename + "output", + M_val, + K_val, + N_val, + True, + "row_major:col_major", + float8_recipe_name, + mx_recipe_name, + gemm_cache_filename, ) + # grad_output @ weight = grad_input bf16_g2, f8_g2 = get_gemm_times( - M_val, N_val, K_val, False, gemm_cache_filename + "grad_input", + M_val, + N_val, + K_val, + False, + "row_major:row_major", + float8_recipe_name, + mx_recipe_name, + gemm_cache_filename, ) + # input_t @ grad_output = grad_weight bf16_g3, f8_g3 = get_gemm_times( - K_val, M_val, N_val, False, gemm_cache_filename + "grad_weight", + K_val, + M_val, + N_val, + False, + "col_major:row_major", + float8_recipe_name, + mx_recipe_name, + gemm_cache_filename, ) b_bf16_gemm_time_s = bf16_g1 + bf16_g2 + bf16_g3 b_fp8_gemm_time_s = f8_g1 + f8_g2 + f8_g3 + rb_bf16_gemm_ratio = r_bf16_gemm_time_s / b_bf16_gemm_time_s + rb_fp8_gemm_ratio = r_fp8_gemm_time_s / b_fp8_gemm_time_s # note: cast from sympy.core.numbers.Float to float to make pandas formatting work r_fp8_ovhd_time_s = float( - fp8_mem_time_sympy_dyn_nolimit.subs(M, M_val).subs(K, K_val).subs(N, N_val) + fp8_ovhd_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val) ) b_bf16_e2e_time_s, b_fp8_e2e_time_s = 0, 0 if do_benchmarks: # create the model - m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16() + if enable_fusion_modeling: + m_orig = LNLinearSigmoid(K_val, N_val).cuda().bfloat16() + else: + m_orig = ( + nn.Sequential(nn.Linear(K_val, N_val, 
bias=False)).cuda().bfloat16() + ) x = torch.randn( M_val, K_val, dtype=torch.bfloat16, device="cuda" ).requires_grad_() + # get the gradient of the right shape + grad_output = torch.randn(N_val, K_val, dtype=torch.bfloat16, device="cuda") + # get the bf16 gpu kernel time torch._dynamo.reset() m_bf16 = torch.compile(copy.deepcopy(m_orig)) - b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x) + b_bf16_e2e_time_s = get_gpu_kernel_time(m_bf16, x, grad_output) # get the float8 dynamic scaling gpu kernel time torch._dynamo.reset() - m_fp8_dyn = convert_to_float8_training(copy.deepcopy(m_orig)) + if float8_recipe_name is not None: + config = Float8LinearConfig.from_recipe_name(float8_recipe_name) + m_fp8_dyn = convert_to_float8_training( + copy.deepcopy(m_orig), config=config + ) + else: + assert mx_recipe_name is not None + config = MXLinearConfig.from_recipe_name(mx_recipe_name) + m_fp8_dyn = copy.deepcopy(m_orig) + swap_linear_with_mx_linear(m_fp8_dyn, config=config) m_fp8_dyn = torch.compile(m_fp8_dyn) - b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x) + b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x, grad_output) results.append( [ @@ -295,6 +412,9 @@ def run( b_bf16_e2e_time_s, b_fp8_e2e_time_s, b_bf16_e2e_time_s / (b_fp8_e2e_time_s + 1e-20), + # gemm ratios + rb_bf16_gemm_ratio, + rb_fp8_gemm_ratio, ] ) diff --git a/benchmarks/float8/utils.py b/benchmarks/float8/utils.py index f12c836a17..5c05100f4d 100644 --- a/benchmarks/float8/utils.py +++ b/benchmarks/float8/utils.py @@ -152,18 +152,32 @@ def get_name_to_shapes_iter( } return name_to_shapes_70b.items() - elif shape_gen_name == "square": + elif shape_gen_name == "pow2": assert ( M == K == N == None ), f"M, K, N arguments not supported for shape_gen_name {shape_gen_name}" name_to_shapes = {} - min_power_of_2 = 8 # 256 - max_power_of_2 = 15 # 32,768 + min_power_of_2 = 10 # 1024 + max_power_of_2 = 14 # 16,384 for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)): val = 2**power_of_2 name_to_shapes[idx] = val, val, val return name_to_shapes.items() + elif shape_gen_name == "pow2_extended": + assert ( + M == K == N == None + ), f"M, K, N arguments not supported for shape_gen_name {shape_gen_name}" + name_to_shapes = {} + min_power_of_2 = 10 # 1024 + max_power_of_2 = 14 # 16,384 + for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)): + val1 = 2**power_of_2 + name_to_shapes[idx * 2] = val1, val1, val1 + val2 = 2**power_of_2 + 2 ** (power_of_2 - 1) + name_to_shapes[idx * 2 + 1] = val2, val2, val2 + return name_to_shapes.items() + elif shape_gen_name == "sweep": assert ( M == K == N == None diff --git a/torchao/testing/float8/roofline_utils.py b/torchao/testing/float8/roofline_utils.py index 458acf8f7b..c7c3b4531e 100644 --- a/torchao/testing/float8/roofline_utils.py +++ b/torchao/testing/float8/roofline_utils.py @@ -4,6 +4,9 @@ # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. 
+from typing import List, Optional, Union + +import sympy import torch BYTES_PER_EL_FLOAT8 = 1 @@ -16,8 +19,8 @@ "fp8_peak_tops": 1979e12, # 2.4 TB per second, custom to Meta's H100 variant "peak_mem_bw_bytes_sec": 2.4e12, - # based on quick experimental observation with sample large inputs - "pct_achievable_gemm_tops": 0.6, + # based on experimental observation with sample large inputs + "pct_achievable_gemm_tops": 0.78, # based on previous experience looking at pointwise triton kernels with large inputs, # which would hit about 2.2k GBPS on Meta's H100 variant "pct_achievable_mem_bw": 0.92, @@ -33,7 +36,7 @@ "peak_mem_bw_bytes_sec": 8.0e12, # for now, copy over from H100 # TODO(future): measure once we have the hardware - "pct_achievable_gemm_tops": 0.6, + "pct_achievable_gemm_tops": 0.78, # for now, copy over from H100 # TODO(future): measure once we have the hardware "pct_achievable_mem_bw": 0.92, @@ -49,49 +52,235 @@ def get_specs(): # Source: run a triton kernel with a single element read/write on an H100 and # measure GPU time from the trace -TRITON_KERNEL_1_ELEMENT_TIME_SEC = 0.002 * 0.001 +# TODO(future): audit this across different hardware and triton/non-triton +KERNEL_LAUNCH_OVERHEAD_SEC = 0.002 * 0.001 -def get_tensor_memory_traffic_bytes( +def get_tensor_memory_traffic_ovhd_s( + specs, dim0, dim1, + tensor_role: str, + float8_recipe_name: Optional[str], + mx_recipe_name: Optional[str], fuse_with_prev=False, -): +) -> List[Union[sympy.Symbol, float]]: + """ + Calculates the roofline estimate of casting one of the gemm inputs + (input, weight or grad_output) to float8 in fwd+bwd. + + Inputs: dim0 and dim1 (shape), tensor_role (input|weight|grad_output), recipe names + Outputs: list of read/write traffic overhead in seconds, one for each kernel + """ # assumes input bf16, output f8 numel = dim0 * dim1 - # x_bf16 = ... - # kernel 1: x_bf16 -> max_abs_stage_1 -> tmp - # kernel 2 (not modeled): tmp -> max_abs_stage_2 -> max_abs - # kernel 3: x_bf16, max_abs -> to_float8 -> x_fp8 + res_bytes = None + if float8_recipe_name == "tensorwise": + if tensor_role == "weight": + # x_bf16 = ... + # kernel 1: x_bf16 -> max_abs_stage_1 -> tmp + # kernel 2 (mem traffic not modeled): tmp -> max_abs_stage_2 -> max_abs + # kernel 3 (fwd): x_bf16, max_abs -> to_float8 -> x_fp8_dim0 + # kernel 4 (bwd): x_bf16, max_abs -> to_float8 -> x_fp8_dim1 + if fuse_with_prev: + kernel_1_rw = 0 + else: + # kernel 1: read numel, write 0 (assume size(tmp) ~ 0) + kernel_1_rw = BYTES_PER_EL_BF16 * numel + # kernel 3: read in bf16, write twice in float8 (row-major and col-major) + kernel_3_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + kernel_4_rw = kernel_3_rw + res_bytes = [kernel_1_rw, 0, kernel_3_rw, kernel_4_rw] + else: + # x_bf16 = ... + # kernel 1: x_bf16 -> max_abs_stage_1 -> tmp + # kernel 2 (mem traffic not modeled): tmp -> max_abs_stage_2 -> max_abs + # kernel 3: x_bf16, max_abs -> to_float8 -> x_fp8_dim0, x_fp8_dim1 + if fuse_with_prev: + kernel_1_rw = 0 + else: + # kernel 1: read numel, write 0 (assume size(tmp) ~ 0) + kernel_1_rw = BYTES_PER_EL_BF16 * numel + # kernel 3: read in bf16, write twice in float8 (row-major and col-major) + kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel + res_bytes = [kernel_1_rw, 0, kernel_3_rw] + + elif float8_recipe_name == "rowwise": + if tensor_role == "weight": + # x_bf16 = ... 
+ # kernel 1 (fwd): x_bf16_dim0 -> x_float8_dim0 + # kernel 2 (bwd): x_bf16_dim0 -> x_bf16_dim1 + # kernel 3 (bwd): x_bf16_dim1 -> x_float8_dim1 + # assume that we can't fuse 2 and 3 because that would require loading + # the entire tensor to shared memory + if fuse_with_prev: + # assume we can fuse one of the reads with previous op + kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel + else: + kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + kernel_2_rw = BYTES_PER_EL_BF16 * numel * 2 + kernel_3_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + res_bytes = [kernel_1_rw, kernel_2_rw, kernel_3_rw] + else: + # x_bf16 = ... + # kernel 1: x_bf16_dim0 -> x_float8_dim0, x_bf16_dim1 + # kernel 2: x_bf16_dim1 -> x_float8_dim1 + # assume that we can't fuse 1 and 2 because that would require loading + # the entire tensor to shared memory + if fuse_with_prev: + # assume we can fuse one of the reads with previous op + kernel_1_rw = ( + 0 + BYTES_PER_EL_FLOAT8 * numel + BYTES_PER_EL_BF16 * numel + ) + else: + kernel_1_rw = ( + BYTES_PER_EL_BF16 * numel + + BYTES_PER_EL_FLOAT8 * numel + + BYTES_PER_EL_BF16 * numel + ) + kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + res_bytes = [kernel_1_rw, kernel_2_rw] + + elif float8_recipe_name == "rowwise_with_gw_hp": + if tensor_role in ("input", "grad_output"): + # x_bf16 = ... + # kernel 1 (fwd): x_bf16_dim0 -> x_float8_dim0 + # bwd: no-op + if fuse_with_prev: + kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel + else: + kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + res_bytes = [kernel_1_rw] + elif tensor_role == "weight": + # x_bf16 = ... + # kernel 1 (fwd): w_bf16 -> w_float8_dim0, w_scale_dim0 + # kernel 2 (bwd): w_scale_dim0 -> w_scale_tensorwise + # kernel 3 (bwd): w_bf16, w_scale_tensorwise -> w_float8_dim1 + kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + kernel_2_rw = 0 + kernel_3_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + res_bytes = [kernel_1_rw, kernel_2_rw, kernel_3_rw] + else: + assert False, "unsupported" - if fuse_with_prev: - kernel_1_rw = 0 else: - # kernel 1: read numel, write 0 (assume size(tmp) ~ 0) - kernel_1_rw = BYTES_PER_EL_BF16 * numel + assert mx_recipe_name in ("mxfp8_emulated", "mxfp8_cutlass"), "unsupported" - # kernel 3: read in bf16, write twice in float8 (row-major and col-major) - kernel_3_rw = BYTES_PER_EL_BF16 * numel + 2 * BYTES_PER_EL_FLOAT8 * numel + if tensor_role == "weight": + # x_bf16 = ... + # kernel 1: x_bf16 -> x_mxfp8_dim0 + # kernel 2: x_bf16 -> x_mxfp8_dim1 + if fuse_with_prev: + kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel + else: + kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + res_bytes = [kernel_1_rw, kernel_2_rw] + else: + # x_bf16 = ... 
+ # kernel 1: x_bf16 -> x_mxfp8_dim0, x_mxfp8_dim1 + if fuse_with_prev: + kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel * 2 + else: + kernel_1_rw = ( + BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel * 2 + ) + res_bytes = [kernel_1_rw] - return kernel_1_rw + kernel_3_rw + # convert from bytes to seconds + res_s = [ + x / specs["peak_mem_bw_bytes_sec"] / specs["pct_achievable_mem_bw"] + for x in res_bytes + ] + # take max of kernel_overhead, r/w time + res_s = [sympy.Max(x, KERNEL_LAUNCH_OVERHEAD_SEC) for x in res_s] -def get_gemm_time_sympy(M, K, N, dtype): + return res_s + + +def get_individual_gemm_time_sympy( + M: sympy.Symbol, K: sympy.Symbol, N: sympy.Symbol, dtype, mx_recipe_name +) -> sympy.Symbol: + # compute bound specs = get_specs() - gemm_ops = 2 * M * K * N + 2 * M * N * K + 2 * K * M * N + gemm_ops = 2 * M * K * N if dtype is torch.bfloat16: peak_tops = specs["bf16_peak_tops"] elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2): peak_tops = specs["fp8_peak_tops"] - gemm_time_s = gemm_ops / peak_tops / specs["pct_achievable_gemm_tops"] - return gemm_time_s + else: + assert False, "unsupported" + compute_gemm_time_s = gemm_ops / peak_tops / specs["pct_achievable_gemm_tops"] + + # memory bound + num_reads = M * K + K * N + num_writes = M * N + + if mx_recipe_name is not None: + assert mx_recipe_name in ("mxfp8_emulated", "mxfp8_cutlass"), "unsupported" + assert dtype in (torch.float8_e4m3fn, torch.float8_e5m2), "unsupported" + # adjust reads for MX scaling + block_size = 32 + num_scale_reads = num_reads // block_size + # note: e8m0 bytes per element is the same as for e4m3|e5m2 + num_reads = num_reads + num_scale_reads + + if dtype is torch.bfloat16: + bytes_rw = num_reads * BYTES_PER_EL_BF16 + num_writes * BYTES_PER_EL_BF16 + elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + # read in float8, output in bfloat16 + bytes_rw = num_reads * BYTES_PER_EL_FLOAT8 + num_writes * BYTES_PER_EL_BF16 + else: + assert False, "unsupported" + mem_gemm_time_s = ( + bytes_rw / specs["peak_mem_bw_bytes_sec"] / specs["pct_achievable_mem_bw"] + ) + + return sympy.Max(compute_gemm_time_s, mem_gemm_time_s, KERNEL_LAUNCH_OVERHEAD_SEC) + + +def get_gemm_time_sympy( + M: sympy.Symbol, + K: sympy.Symbol, + N: sympy.Symbol, + dtype, + float8_recipe_name: Optional[str], + mx_recipe_name: Optional[str], +): + # next: add rowwise_with_gw_hp here + # note: this function is currently not super accurate for small shapes: + # when M,K,N <= 1k,1k,1k it undercounts by around 2x + + gemm_dtype_input, gemm_dtype_grad_input, gemm_dtype_grad_weight = ( + dtype, + dtype, + dtype, + ) + if float8_recipe_name == "rowwise_with_gw_hp": + gemm_dtype_grad_weight = torch.bfloat16 + + gemm_output_time_s = get_individual_gemm_time_sympy( + M, K, N, gemm_dtype_input, mx_recipe_name + ) + gemm_grad_input_time_s = get_individual_gemm_time_sympy( + M, N, K, gemm_dtype_grad_input, mx_recipe_name + ) + gemm_grad_weight_time_s = get_individual_gemm_time_sympy( + K, M, N, gemm_dtype_grad_weight, mx_recipe_name + ) + total = gemm_output_time_s + gemm_grad_input_time_s + gemm_grad_weight_time_s + return total def get_float8_mem_sympy( M, K, N, + float8_recipe_name: Optional[str], + mx_recipe_name: Optional[str], + enable_fusion_modeling: bool, ): specs = get_specs() @@ -106,65 +295,33 @@ def get_float8_mem_sympy( # input_t @ grad_output = grad_weight # KxM @ MxN => KxN - # - # forward - output - # - fwd_fp8_input_mem = get_tensor_memory_traffic_bytes( + fwd_fp8_input_mem = get_tensor_memory_traffic_ovhd_s( + specs, M, 
K, - fuse_with_prev=True, + tensor_role="input", + float8_recipe_name=float8_recipe_name, + mx_recipe_name=mx_recipe_name, + fuse_with_prev=enable_fusion_modeling, ) - fwd_fp8_weight_mem = get_tensor_memory_traffic_bytes( + fwd_fp8_weight_mem = get_tensor_memory_traffic_ovhd_s( + specs, K, N, + tensor_role="weight", + float8_recipe_name=float8_recipe_name, + mx_recipe_name=mx_recipe_name, fuse_with_prev=False, ) - fwd_fp8_total_mem = fwd_fp8_input_mem + fwd_fp8_weight_mem - - # - # backward - grad_input - # - gi_fp8_grad_output_mem = get_tensor_memory_traffic_bytes( + gi_fp8_grad_output_mem = get_tensor_memory_traffic_ovhd_s( + specs, M, N, - fuse_with_prev=True, - ) - # already casted, assuming that we save weight from fw to bw - # TODO: model this if FSDP float8 all-gather is on - # TODO: model this if we don't save weight from fw to bw, and recompute instead - gi_fp8_weight_mem = 0 - - # - # backward - grad_weight - # - # TODO: model this if we don't save fp8 input from fw to bw - gw_fp8_input_t_mem = 0 # already casted - # this should be always 0 - gw_fp8_grad_output_mem = 0 # already casted - - bwd_fp8_total_mem = ( - gi_fp8_grad_output_mem - + gi_fp8_weight_mem - + gw_fp8_input_t_mem - + gw_fp8_grad_output_mem - ) - fp8_total_mem = fwd_fp8_total_mem + bwd_fp8_total_mem - fp8_mem_time_s = ( - fp8_total_mem / specs["peak_mem_bw_bytes_sec"] / specs["pct_achievable_mem_bw"] + tensor_role="grad_output", + float8_recipe_name=float8_recipe_name, + mx_recipe_name=mx_recipe_name, + fuse_with_prev=enable_fusion_modeling, ) - # Adjust final estimate for small kernel launches - # note that we do this adjustment here because we are assuming a minimal - # kernel overhead in the units of seconds, and the per-gemm-input memory - # estimations are in the units of bytes. 
- num_extra_kernels = 0 - # second stage of max-abs reduction for input - num_extra_kernels += 1 - # second stage of max-abs reduction for weight - num_extra_kernels += 1 - # second stage of max-abs reduction for grad_output - num_extra_kernels += 1 - - extra_kernel_overhead_s = num_extra_kernels * TRITON_KERNEL_1_ELEMENT_TIME_SEC - - return fp8_mem_time_s + extra_kernel_overhead_s + res = sum([*fwd_fp8_input_mem, *fwd_fp8_weight_mem, *gi_fp8_grad_output_mem]) + return res From 3f24e0c87077f918aa63f54cee89282ea1953ddd Mon Sep 17 00:00:00 2001 From: Manuel Candales <42380156+manuelcandales@users.noreply.github.com> Date: Tue, 4 Mar 2025 12:17:10 -0500 Subject: [PATCH 08/11] metal lowbit ops: ci (#1825) --- .../workflows/torchao_experimental_test.yml | 53 ++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/.github/workflows/torchao_experimental_test.yml b/.github/workflows/torchao_experimental_test.yml index 0646eebc03..122e5ad9f6 100644 --- a/.github/workflows/torchao_experimental_test.yml +++ b/.github/workflows/torchao_experimental_test.yml @@ -11,7 +11,7 @@ on: - 'gh/**' jobs: - test: + test-cpu-ops: strategy: matrix: runner: [macos-14] @@ -56,3 +56,54 @@ jobs: sh build_and_run_tests.sh rm -rf /tmp/cmake-out popd + + test-mps-ops: + strategy: + matrix: + runner: [macos-m1-stable] + runs-on: ${{matrix.runner}} + steps: + - name: Print machine info + run: | + uname -a + if [ $(uname -s) == Darwin ]; then + sysctl machdep.cpu.brand_string + sysctl machdep.cpu.core_count + fi + - name: Checkout repo + uses: actions/checkout@v3 + with: + submodules: true + - name: Create conda env + run: | + conda create -yn test-mps-ops-env python=3.11 + - name: Activate conda env + run: | + source activate base + conda activate test-mps-ops-env + - name: Install torch + run: | + pip install torch --index-url "https://download.pytorch.org/whl/nightly/cpu" + - name: Print torch version + run: | + python -c "import torch; print(torch.__version__)" + - name: Install requirements + run: | + pip install cmake + pip install parameterized + pip install pyyaml + - name: Print pip freeze + run: | + pip freeze + - name: Print current directory + run: | + python -c "import os; print(os.getcwd())" + - name: Build ao with experimental mps ops + run: | + USE_CPP=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 pip install . 
+ - name: Run mps tests + run: | + pushd torchao/experimental/ops/mps/test + python test_lowbit.py + python test_quantizer.py + popd From 9322ee19bd3b09009b486801f4fd663cfb30233f Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Tue, 4 Mar 2025 09:55:55 -0800 Subject: [PATCH 09/11] Fix experimental CI (#1827) init --- .github/workflows/torchao_experimental_test.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/torchao_experimental_test.yml b/.github/workflows/torchao_experimental_test.yml index 122e5ad9f6..ba2cd800a6 100644 --- a/.github/workflows/torchao_experimental_test.yml +++ b/.github/workflows/torchao_experimental_test.yml @@ -53,8 +53,8 @@ jobs: run: | conda activate venv pushd torchao/experimental/ops/tests - sh build_and_run_tests.sh - rm -rf /tmp/cmake-out + # sh build_and_run_tests.sh + # rm -rf /tmp/cmake-out popd test-mps-ops: @@ -92,6 +92,7 @@ jobs: pip install cmake pip install parameterized pip install pyyaml + pip install numpy - name: Print pip freeze run: | pip freeze From 28ccd7391f88d0acdd8255b9130ce79e548a5d9d Mon Sep 17 00:00:00 2001 From: Scott Roy <161522778+metascroy@users.noreply.github.com> Date: Tue, 4 Mar 2025 10:29:45 -0800 Subject: [PATCH 10/11] Optionally enable KleidiAI + clean up setup.py flags (#1826) * init * up * up * up * up * up * up * up * up * up --- setup.py | 115 +++++++++++++----- torchao/experimental/CMakeLists.txt | 85 +++++++------ .../kernels/cpu/aarch64/CMakeLists.txt | 7 +- .../kernels/cpu/aarch64/tests/CMakeLists.txt | 3 +- .../cpu/aarch64/tests/build_and_run_tests.sh | 1 + .../ops/embedding_xbit/CMakeLists.txt | 8 +- .../embedding_xbit/op_embedding_xbit-impl.h | 12 +- .../CMakeLists.txt | 10 +- .../kernel_selector.h | 9 +- .../op_linear_8bit_act_xbit_weight-impl.h | 4 - torchao/experimental/ops/tests/CMakeLists.txt | 19 ++- .../ops/tests/build_and_run_tests.sh | 3 +- 12 files changed, 173 insertions(+), 103 deletions(-) diff --git a/setup.py b/setup.py index e1bad04cd2..b16f78eb35 100644 --- a/setup.py +++ b/setup.py @@ -9,6 +9,7 @@ import sys import time from datetime import datetime +from typing import List, Optional from setuptools import Extension, find_packages, setup @@ -75,19 +76,54 @@ def use_debug_mode(): CUDAExtension, ) -build_torchao_experimental_mps = ( - os.getenv("TORCHAO_BUILD_EXPERIMENTAL_MPS") == "1" - and build_torchao_experimental - and torch.mps.is_available() -) -if os.getenv("TORCHAO_BUILD_EXPERIMENTAL_MPS") == "1": - if use_cpp != "1": - print("Building experimental MPS ops requires USE_CPP=1") - if not platform.machine().startswith("arm64") or platform.system() != "Darwin": - print("Experimental MPS ops require Apple Silicon.") - if not torch.mps.is_available(): - print("MPS not available. Skipping compilation of experimental MPS ops.") +class BuildOptions: + def __init__(self): + # TORCHAO_BUILD_CPU_AARCH64 is enabled by default on Arm-based Apple machines + # The kernels require sdot/udot, which are not required on Arm until Armv8.4 or later, + # but are available on Arm-based Apple machines. 
On non-Apple machines, the kernels + # can be built by explicitly setting TORCHAO_BUILD_CPU_AARCH64=1 + self.build_cpu_aarch64 = self._os_bool_var( + "TORCHAO_BUILD_CPU_AARCH64", + default=(self._is_arm64() and self._is_macos()), + ) + if self.build_cpu_aarch64: + assert ( + self._is_arm64() + ), "TORCHAO_BUILD_CPU_AARCH64 requires an arm64 machine" + + # TORCHAO_BUILD_KLEIDIAI is disabled by default for now because + # 1) It increases the build time + # 2) It has some accuracy issues in CI tests due to BF16 + self.build_kleidi_ai = self._os_bool_var( + "TORCHAO_BUILD_KLEIDIAI", default=False + ) + if self.build_kleidi_ai: + assert ( + self.build_cpu_aarch64 + ), "TORCHAO_BUILD_KLEIDIAI requires TORCHAO_BUILD_CPU_AARCH64 be set" + + # TORCHAO_BUILD_EXPERIMENTAL_MPS is disabled by default. + self.build_experimental_mps = self._os_bool_var( + "TORCHAO_BUILD_EXPERIMENTAL_MPS", default=False + ) + if self.build_experimental_mps: + assert self._is_macos(), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MacOS" + assert self._is_arm64(), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires arm64" + assert ( + torch.mps.is_available() + ), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MPS be available" + + def _is_arm64(self) -> bool: + return platform.machine().startswith("arm64") + + def _is_macos(self) -> bool: + return platform.system() == "Darwin" + + def _os_bool_var(self, var, default) -> bool: + default_val = "1" if default else "0" + return os.getenv(var, default_val) == "1" + # Constant known variables used throughout this file cwd = os.path.abspath(os.path.curdir) @@ -179,38 +215,30 @@ def build_extensions(self): def build_cmake(self, ext): extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) - build_type = "Debug" if use_debug_mode() else "Release" - - from distutils.sysconfig import get_python_lib - - torch_dir = get_python_lib() + "/torch/share/cmake/Torch" - if not os.path.exists(self.build_temp): os.makedirs(self.build_temp) - build_mps_ops = "ON" if build_torchao_experimental_mps else "OFF" - subprocess.check_call( [ "cmake", - ext.sourcedir, - "-DCMAKE_BUILD_TYPE=" + build_type, - # Disable now because 1) KleidiAI increases build time, and 2) KleidiAI has accuracy issues due to BF16 - "-DTORCHAO_BUILD_KLEIDIAI=OFF", - "-DTORCHAO_BUILD_MPS_OPS=" + build_mps_ops, - "-DTorch_DIR=" + torch_dir, - "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir, - "-DCMAKE_INSTALL_PREFIX=cmake-out", - ], + ext.cmake_lists_dir, + ] + + ext.cmake_args + + ["-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir], cwd=self.build_temp, ) subprocess.check_call(["cmake", "--build", "."], cwd=self.build_temp) class CMakeExtension(Extension): - def __init__(self, name, sourcedir=""): + def __init__( + self, name, cmake_lists_dir: str = "", cmake_args: Optional[List[str]] = None + ): Extension.__init__(self, name, sources=[]) - self.sourcedir = os.path.abspath(sourcedir) + self.cmake_lists_dir = os.path.abspath(cmake_lists_dir) + if cmake_args is None: + cmake_args = [] + self.cmake_args = cmake_args def get_extensions(): @@ -310,10 +338,33 @@ def get_extensions(): ) if build_torchao_experimental: + build_options = BuildOptions() + + def bool_to_on_off(value): + return "ON" if value else "OFF" + + from distutils.sysconfig import get_python_lib + + torch_dir = get_python_lib() + "/torch/share/cmake/Torch" + ext_modules.append( CMakeExtension( "torchao.experimental", - sourcedir="torchao/experimental", + cmake_lists_dir="torchao/experimental", + cmake_args=( + [ + f"-DCMAKE_BUILD_TYPE={'Debug' if use_debug_mode() else 
'Release'}", + f"-DTORCHAO_BUILD_CPU_AARCH64={bool_to_on_off(build_options.build_cpu_aarch64)}", + f"-DTORCHAO_BUILD_KLEIDIAI={bool_to_on_off(build_options.build_kleidi_ai)}", + f"-DTORCHAO_BUILD_MPS_OPS={bool_to_on_off(build_options.build_experimental_mps)}", + "-DTorch_DIR=" + torch_dir, + ] + + ( + ["-DCMAKE_INSTALL_PREFIX=cmake-out"] + if build_options.build_experimental_mps + else [] + ) + ), ) ) diff --git a/torchao/experimental/CMakeLists.txt b/torchao/experimental/CMakeLists.txt index 67dfc7b779..e161cb8946 100644 --- a/torchao/experimental/CMakeLists.txt +++ b/torchao/experimental/CMakeLists.txt @@ -17,17 +17,13 @@ endif() option(TORCHAO_BUILD_EXECUTORCH_OPS "Building torchao ops for ExecuTorch." OFF) option(TORCHAO_BUILD_MPS_OPS "Building torchao MPS ops" OFF) - +option(TORCHAO_BUILD_CPU_AARCH64 "Build torchao's CPU aarch64 kernels" OFF) +option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" OFF) if(NOT TORCHAO_INCLUDE_DIRS) set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../..) endif() -option(TORCHAO_BUILD_KLEIDIAI "Download, build, and link against Arm KleidiAI library (arm64 only)" OFF) -if(TORCHAO_BUILD_KLEIDIAI) - message(STATUS "Building with Arm KleidiAI library") - add_compile_definitions(TORCHAO_ENABLE_KLEIDI=1) -endif() include(CMakePrintHelpers) add_compile_options("-Wall" "-Werror" "-Wno-deprecated") @@ -36,49 +32,52 @@ include(CMakePrintHelpers) message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}") include_directories(${TORCHAO_INCLUDE_DIRS}) -if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") + +if(TORCHAO_BUILD_CPU_AARCH64) + message(STATUS "Building with cpu/aarch64") + add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64) + + # Defines torchao_kernels_aarch64 + add_subdirectory(kernels/cpu/aarch64) + if(TORCHAO_BUILD_KLEIDIAI) message(STATUS "Building with Arm KleidiAI library") - add_compile_definitions(TORCHAO_ENABLE_KLEIDI=1) - endif() - # Defines target torchao_kernels_aarch64 - add_subdirectory(kernels/cpu/aarch64) - add_subdirectory(ops/linear_8bit_act_xbit_weight) - add_subdirectory(ops/embedding_xbit) - - add_library(torchao_ops_aten SHARED) - target_link_libraries( - torchao_ops_aten PRIVATE - torchao_ops_linear_8bit_act_xbit_weight_aten - torchao_ops_embedding_xbit_aten - ) - if (TORCHAO_BUILD_MPS_OPS) - message(STATUS "Building with MPS support") - add_subdirectory(ops/mps) - target_link_libraries(torchao_ops_aten PRIVATE torchao_ops_mps_aten) + add_compile_definitions(TORCHAO_ENABLE_KLEIDI) endif() +endif() + +add_subdirectory(ops/linear_8bit_act_xbit_weight) +add_subdirectory(ops/embedding_xbit) +add_library(torchao_ops_aten SHARED) +target_link_libraries( + torchao_ops_aten PRIVATE + torchao_ops_linear_8bit_act_xbit_weight_aten + torchao_ops_embedding_xbit_aten +) +if (TORCHAO_BUILD_MPS_OPS) + message(STATUS "Building with MPS support") + add_subdirectory(ops/mps) + target_link_libraries(torchao_ops_aten PRIVATE torchao_ops_mps_aten) +endif() + +install( + TARGETS torchao_ops_aten + EXPORT _targets + DESTINATION lib +) +if(TORCHAO_BUILD_EXECUTORCH_OPS) + add_library(torchao_ops_executorch STATIC) + target_link_libraries(torchao_ops_executorch PRIVATE + torchao_ops_linear_8bit_act_xbit_weight_executorch + torchao_ops_embedding_xbit_executorch + ) install( - TARGETS torchao_ops_aten + TARGETS + torchao_ops_executorch + torchao_ops_linear_8bit_act_xbit_weight_executorch + torchao_ops_embedding_xbit_executorch EXPORT _targets DESTINATION lib ) - if(TORCHAO_BUILD_EXECUTORCH_OPS) - 
add_library(torchao_ops_executorch STATIC) - target_link_libraries(torchao_ops_executorch PRIVATE - torchao_ops_linear_8bit_act_xbit_weight_executorch - torchao_ops_embedding_xbit_executorch - ) - install( - TARGETS - torchao_ops_executorch - torchao_kernels_aarch64 - torchao_ops_linear_8bit_act_xbit_weight_executorch - torchao_ops_embedding_xbit_executorch - EXPORT _targets - DESTINATION lib - ) - endif() -else() - message(FATAL_ERROR "Torchao experimental ops can only be built on arm64 CPUs.") endif() diff --git a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt index bb4d9ac22f..3cca338cbf 100644 --- a/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt @@ -4,7 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -if ((CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") OR (CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64")) +if (TORCHAO_BUILD_CPU_AARCH64) add_library( torchao_kernels_aarch64 ${CMAKE_CURRENT_SOURCE_DIR}/reduction/find_min_and_max.cpp @@ -22,14 +22,11 @@ if ((CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") OR (CMAKE_SYSTEM_PROCESSOR STREQUA GIT_TAG v1.2.0) FetchContent_MakeAvailable(kleidiai) - # Temporarily exposing this to the parent scope until we wire - # this up properly from the top level - set(TORCHAO_BUILD_KLEIDI ON PARENT_SCOPE) target_link_libraries(torchao_kernels_aarch64 PUBLIC kleidiai) endif() -endif() install( TARGETS torchao_kernels_aarch64 DESTINATION lib ) +endif() diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt index e4cafdc97a..7f97703588 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt +++ b/torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt @@ -40,8 +40,7 @@ endif() add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64) -# The TORCHAO_BUILD_KLEIDI cmake variable should be set by `torchao_kernels_aarch64" -if(TORCHAO_BUILD_KLEIDI) +if(TORCHAO_BUILD_KLEIDIAI) add_compile_definitions(TORCHAO_ENABLE_KLEIDI) endif() diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh index 39cc76d887..2094c5df12 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh +++ b/torchao/experimental/kernels/cpu/aarch64/tests/build_and_run_tests.sh @@ -40,6 +40,7 @@ cmake \ ${EXTRA_ARGS} \ -DCMAKE_BUILD_TYPE=Debug \ -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \ + -DTORCHAO_BUILD_CPU_AARCH64=ON \ -S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/tests \ -B ${CMAKE_OUT} diff --git a/torchao/experimental/ops/embedding_xbit/CMakeLists.txt b/torchao/experimental/ops/embedding_xbit/CMakeLists.txt index 221b41074e..80c5bbc7be 100644 --- a/torchao/experimental/ops/embedding_xbit/CMakeLists.txt +++ b/torchao/experimental/ops/embedding_xbit/CMakeLists.txt @@ -13,7 +13,9 @@ add_library(torchao_ops_embedding_xbit_aten OBJECT op_embedding_xbit_aten.cpp ) target_link_torchao_parallel_backend(torchao_ops_embedding_xbit_aten "aten_openmp") -target_link_libraries(torchao_ops_embedding_xbit_aten PRIVATE torchao_kernels_aarch64) +if (TORCHAO_BUILD_CPU_AARCH64) + target_link_libraries(torchao_ops_embedding_xbit_aten PRIVATE torchao_kernels_aarch64) +endif() 
target_include_directories(torchao_ops_embedding_xbit_aten PRIVATE "${TORCH_INCLUDE_DIRS}") target_link_libraries(torchao_ops_embedding_xbit_aten PRIVATE "${TORCH_LIBRARIES}") target_compile_definitions(torchao_ops_embedding_xbit_aten PRIVATE USE_ATEN=1) @@ -32,5 +34,7 @@ if(TORCHAO_BUILD_EXECUTORCH_OPS) target_include_directories(torchao_ops_embedding_xbit_executorch PRIVATE "${EXECUTORCH_INCLUDE_DIRS}") target_compile_definitions(torchao_ops_embedding_xbit_executorch PRIVATE USE_EXECUTORCH=1) target_link_libraries(torchao_ops_embedding_xbit_executorch PRIVATE "${EXECUTORCH_LIBRARIES}") - target_link_libraries(torchao_ops_embedding_xbit_executorch PRIVATE torchao_kernels_aarch64) + if (TORCHAO_BUILD_CPU_AARCH64) + target_link_libraries(torchao_ops_embedding_xbit_executorch PRIVATE torchao_kernels_aarch64) + endif() endif() diff --git a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h b/torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h index 777ec740ca..bf3f9fb7bb 100644 --- a/torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h +++ b/torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h @@ -6,9 +6,9 @@ #pragma once -#if defined(__aarch64__) || defined(__ARM_NEON) +#if defined(TORCHAO_BUILD_CPU_AARCH64) #include -#endif // defined(__aarch64__) || defined(__ARM_NEON) +#endif // TORCHAO_BUILD_CPU_AARCH64 #include #include @@ -145,7 +145,7 @@ Tensor embedding_out_cpu( index = index64_ptr[idx]; } TORCHAO_CHECK(index >= 0 && index < num_embeddings, "index out of bounds"); -#if defined(__aarch64__) || defined(__ARM_NEON) +#if defined(TORCHAO_BUILD_CPU_AARCH64) torchao::kernels::cpu::aarch64::embedding::embedding( out.mutable_data_ptr() + idx * embedding_dim, embedding_dim, @@ -157,7 +157,7 @@ Tensor embedding_out_cpu( index); #else TORCHAO_CHECK(false, "Unsupported platform"); -#endif // defined(__aarch64__) || defined(__ARM_NEON) +#endif // TORCHAO_BUILD_CPU_AARCH64 }); return out; @@ -234,7 +234,7 @@ Tensor pack_embedding_cpu(const Tensor& weight_qvals) { header.write(out.mutable_data_ptr()); torchao::parallel_1d(0, num_embeddings, [&](int64_t idx) { -#if defined(__aarch64__) || defined(__ARM_NEON) +#if defined(TORCHAO_BUILD_CPU_AARCH64) torchao::kernels::cpu::aarch64::embedding::pack_embedding_weight_qvals< weight_nbit>( out.mutable_data_ptr() + @@ -244,7 +244,7 @@ Tensor pack_embedding_cpu(const Tensor& weight_qvals) { idx); #else TORCHAO_CHECK(false, "Unsupported platform"); -#endif // defined(__aarch64__) || defined(__ARM_NEON) +#endif // defined(TORCHAO_BUILD_CPU_AARCH64) }); return out; diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/CMakeLists.txt b/torchao/experimental/ops/linear_8bit_act_xbit_weight/CMakeLists.txt index 82d9fa2cf3..51f2718691 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/CMakeLists.txt +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/CMakeLists.txt @@ -18,13 +18,17 @@ FetchContent_Declare(cpuinfo FetchContent_MakeAvailable( cpuinfo) + find_package(Torch REQUIRED) add_library(torchao_ops_linear_8bit_act_xbit_weight_aten OBJECT linear_8bit_act_xbit_weight.cpp op_linear_8bit_act_xbit_weight_aten.cpp ) target_link_torchao_parallel_backend(torchao_ops_linear_8bit_act_xbit_weight_aten aten_openmp) -target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE torchao_kernels_aarch64) + +if(TORCHAO_BUILD_CPU_AARCH64) + target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE torchao_kernels_aarch64) +endif() 
target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE cpuinfo) target_include_directories(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE "${TORCH_INCLUDE_DIRS}") target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_aten PRIVATE "${TORCH_LIBRARIES}") @@ -47,6 +51,8 @@ if(TORCHAO_BUILD_EXECUTORCH_OPS) target_include_directories(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE "${EXECUTORCH_INCLUDE_DIRS}") target_compile_definitions(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE USE_EXECUTORCH=1) target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE "${EXECUTORCH_LIBRARIES}") - target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE torchao_kernels_aarch64) + if(TORCHAO_BUILD_CPU_AARCH64) + target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE torchao_kernels_aarch64) + endif() target_link_libraries(torchao_ops_linear_8bit_act_xbit_weight_executorch PRIVATE cpuinfo) endif() diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h index 443d903dfb..c9fcd86bff 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/kernel_selector.h @@ -6,12 +6,13 @@ #pragma once #include +// #include #include #include -#if defined(__aarch64__) || defined(__ARM_NEON) +#if defined(TORCHAO_BUILD_CPU_AARCH64) #include -#endif // defined(__aarch64__) || defined(__ARM_NEON) +#endif // TORCHAO_BUILD_CPU_AARCH64 #include #include @@ -132,7 +133,7 @@ void register_ukernel_config_universal(UKernelConfigRegistrationTable &table, torchao::ops::PackedWeightsType::linear_8bit_act_xbit_weight_universal); if (format.nr == 8 && format.kr == 16 && format.sr == 2) { -#if defined(__aarch64__) || defined(__ARM_NEON) +#if defined(TORCHAO_BUILD_CPU_AARCH64) if (cpuinfo_has_arm_neon_dot()) { namespace kernel = torchao::kernels::cpu::aarch64::linear:: channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot; @@ -159,7 +160,7 @@ void register_ukernel_config_universal(UKernelConfigRegistrationTable &table, has_clamp>}}}}); return; } -#endif // defined(__aarch64__) || defined(__ARM_NEON) +#endif // TORCHAO_BUILD_CPU_AARCH64 } } diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h index 364dd7b668..0e75d409b7 100644 --- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h +++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h @@ -6,10 +6,6 @@ #pragma once -#if defined(__aarch64__) || defined(__ARM_NEON) -#include -#endif // defined(__aarch64__) || defined(__ARM_NEON) - #include #include #include diff --git a/torchao/experimental/ops/tests/CMakeLists.txt b/torchao/experimental/ops/tests/CMakeLists.txt index c3d34d6ba9..8a9ad08f23 100644 --- a/torchao/experimental/ops/tests/CMakeLists.txt +++ b/torchao/experimental/ops/tests/CMakeLists.txt @@ -21,6 +21,11 @@ FetchContent_Declare( ) FetchContent_MakeAvailable(googletest) enable_testing() + +if(TORCHAO_BUILD_CPU_AARCH64) + add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64=1) +endif() + if(TORCHAO_BUILD_KLEIDIAI) add_compile_definitions(TORCHAO_ENABLE_KLEIDI=1) endif() @@ -37,7 +42,11 @@ endif() include_directories(${TORCHAO_INCLUDE_DIRS}) 
set(TORCHAO_PARALLEL_BACKEND "test_dummy") -add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64) + +if (TORCHAO_BUILD_CPU_AARCH64) + add_subdirectory(${TORCHAO_ROOT}/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/torchao_kernels_aarch64) + add_compile_definitions(TORCHAO_BUILD_CPU_AARCH64) +endif() include(${TORCHAO_ROOT}/Utils.cmake) @@ -62,8 +71,14 @@ target_link_libraries( test_linear_8bit_act_xbit_weight PRIVATE GTest::gtest_main - torchao_kernels_aarch64 ) +if (TORCHAO_BUILD_CPU_AARCH64) + target_link_libraries( + test_linear_8bit_act_xbit_weight + PRIVATE + torchao_kernels_aarch64 + ) +endif() target_link_torchao_parallel_backend(test_linear_8bit_act_xbit_weight "${TORCHAO_PARALLEL_BACKEND}") include(GoogleTest) diff --git a/torchao/experimental/ops/tests/build_and_run_tests.sh b/torchao/experimental/ops/tests/build_and_run_tests.sh index cff7ca639a..6a73b91219 100644 --- a/torchao/experimental/ops/tests/build_and_run_tests.sh +++ b/torchao/experimental/ops/tests/build_and_run_tests.sh @@ -9,7 +9,7 @@ target=${1:-"native"} SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) export CMAKE_OUT=/tmp/cmake-out/torch_ao/tests -export TORCH_DIR = $(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib() + '/torch/share/cmake/Torch')") +export TORCH_DIR=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib() + '/torch/share/cmake/Torch')") IS_ARM64=0 BUILD_ARM_I8MM=0 @@ -45,6 +45,7 @@ fi cmake \ ${EXTRA_ARGS} \ -DCMAKE_BUILD_TYPE=Debug \ + -DTORCHAO_BUILD_CPU_AARCH64=${IS_ARM64} \ -DTORCHAO_BUILD_KLEIDIAI=${IS_ARM64} \ -DTORCHAO_BUILD_ARM_I8MM=${BUILD_ARM_I8MM} \ -DTorch_DIR=${TORCH_DIR} \ From e040fb6f8c2b6b6fe24574310199699de08429ac Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Tue, 4 Mar 2025 11:02:59 -0800 Subject: [PATCH 11/11] Update install instructions --- .github/workflows/torchao_experimental_test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/torchao_experimental_test.yml b/.github/workflows/torchao_experimental_test.yml index ba2cd800a6..60d8fa2718 100644 --- a/.github/workflows/torchao_experimental_test.yml +++ b/.github/workflows/torchao_experimental_test.yml @@ -33,7 +33,7 @@ jobs: - name: Install requirements run: | conda activate venv - pip install torch --extra-index-url "https://download.pytorch.org/whl/nightly/cpu" + pip install torch --index-url "https://download.pytorch.org/whl/nightly/cpu" pip install numpy pip install pytest USE_CPP=1 pip install .
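
For reference, a minimal local-build sketch using the flags introduced in PATCH 10/11 and exercised by the CI workflows above. The environment variable names come from setup.py and the workflow files in this series; the specific combinations shown are illustrative only. Note that TORCHAO_BUILD_CPU_AARCH64 defaults to on for Arm-based Apple machines, so setting it explicitly is only needed on other arm64 hosts, and TORCHAO_BUILD_KLEIDIAI requires TORCHAO_BUILD_CPU_AARCH64.

    # CPU aarch64 kernels, optionally with KleidiAI (arm64 only)
    USE_CPP=1 TORCHAO_BUILD_CPU_AARCH64=1 TORCHAO_BUILD_KLEIDIAI=1 pip install .

    # experimental MPS ops (requires arm64 macOS with MPS available)
    USE_CPP=1 TORCHAO_BUILD_EXPERIMENTAL_MPS=1 pip install .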