Merged
3 changes: 3 additions & 0 deletions .gitmodules
@@ -64,3 +64,6 @@
[submodule "third-party/pybind11"]
path = third-party/pybind11
url = https://github.com/pybind/pybind11.git
[submodule "third-party/ao"]
path = third-party/ao
url = https://github.com/pytorch/ao.git
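
Note for anyone checking out this PR: torchao is now vendored as a submodule, so it must be initialized before the install scripts and CMake changes below can find it. A minimal sketch, assuming a fresh clone of the repo:

```bash
# Fetch the newly added ao submodule (fresh-clone workflow).
git submodule sync
git submodule update --init third-party/ao
```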
9 changes: 9 additions & 0 deletions examples/models/llama/CMakeLists.txt
@@ -37,6 +37,8 @@ cmake_dependent_option(
"NOT EXECUTORCH_BUILD_ARM_BAREMETAL" OFF
)

option(EXECUTORCH_BUILD_TORCHAO "Build the torchao kernels" OFF)

if(NOT PYTHON_EXECUTABLE)
  set(PYTHON_EXECUTABLE python3)
endif()
@@ -121,6 +123,13 @@ if(EXECUTORCH_BUILD_KERNELS_CUSTOM)
  list(APPEND link_libraries custom_ops)
endif()

if(EXECUTORCH_BUILD_TORCHAO)
  set(TORCHAO_BUILD_EXECUTORCH_OPS ON)
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../../third-party/ao/torchao/experimental ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/ao/torchao/experimental)
Contributor: I think you can do ${EXECUTORCH_ROOT}/third-party/ao/torchao/experimental

Author: But one is the source directory and the other the binary directory?

Contributor: Yeah, I meant for the first one we can make it simpler. But it's up to you.
  target_link_options_shared_lib(torchao_ops_executorch)
  list(APPEND link_libraries torchao_ops_executorch)
endif()

set(XNNPACK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../backends/xnnpack)
# Extra compile option and include dir for pthreadpool
if(EXECUTORCH_BUILD_PTHREADPOOL)
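To exercise the new option locally, the flag is passed at configure time like any other ExecuTorch build option. A hedged sketch of a configure step; everything here other than -DEXECUTORCH_BUILD_TORCHAO=ON (the build directory, install prefix, and any other flags your build needs) is an assumption, not part of this diff:

```bash
# Hypothetical configure invocation for the llama example with torchao kernels on.
cmake -DEXECUTORCH_BUILD_TORCHAO=ON \
      -DCMAKE_INSTALL_PREFIX=cmake-out \
      -B cmake-out/examples/models/llama \
      examples/models/llama
```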
36 changes: 33 additions & 3 deletions examples/models/llama/export_llama_lib.py
@@ -12,14 +12,14 @@
import copy
import json
import logging
import re
import shlex
from enum import Enum
from json import JSONDecodeError
from pathlib import Path
from typing import Callable, List, Optional, Union

import pkg_resources

import torch

from executorch.devtools.etrecord import generate_etrecord
@@ -153,12 +153,12 @@ def build_args_parser() -> argparse.ArgumentParser:
        ],
        help="Use PT2E quantization. Comma separated options. e.g. xnnpack_dynamic (for per channel 8 bit weight), xnnpack_dynamic_qc4 (for per channel 4 bit weight), embedding.",
    )

    parser.add_argument(
        "-qmode",
        "--quantization_mode",
-        type=str,
+        type=_qmode_type,
        default=None,
-        choices=["int8", "8da4w", "8da4w-gptq", "vulkan_4w"],
        help="type of quantization",
    )

@@ -568,6 +568,23 @@ def get_quantizer_and_quant_params(args):
    return pt2e_quant_params, quantizers, quant_dtype


def _qmode_type(value):
    choices = ["int8", "8da4w", "8da4w-gptq", "vulkan_4w"]
    patterns = [r"torchao:8da(\d+)w"]

    if value in choices:
        return value

    for pattern in patterns:
        matches = re.findall(pattern, value)
        if len(matches) == 1:
            return value

    raise argparse.ArgumentTypeError(
        f"Got qmode {value}, but expected one of {choices}, or one of the regex patterns {patterns}."
    )
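
With this validator in place, --quantization_mode accepts the four fixed choices plus any string matching torchao:8da(\d+)w, where the captured digits are the weight bit width. A few illustrative invocations; the entry point and surrounding flags are assumptions for the sake of example:

```bash
# Accepted: one of the fixed choices.
python -m examples.models.llama.export_llama -qmode 8da4w ...
# Accepted: torchao pattern, 8-bit dynamic activations with 3-bit weights.
# torchao modes also require --disable_dynamic_shape (see _validate_args below).
python -m examples.models.llama.export_llama -qmode torchao:8da3w --disable_dynamic_shape ...
# Rejected by _qmode_type: matches neither a choice nor the pattern.
python -m examples.models.llama.export_llama -qmode torchao:4w ...
```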


def _validate_args(args):
    """
    TODO: Combine all the backends under --backend args
@@ -581,6 +598,19 @@ def _validate_args(args):
    if args.num_sharding > 0 and not args.qnn:
        raise ValueError("Model shard is only supported with qnn backend now.")

    if (
        args.quantization_mode is not None
        and args.quantization_mode.startswith("torchao:")
    ) or (
        args.embedding_quantize is not None
        and args.embedding_quantize.startswith("torchao:")
    ):
        if args.enable_dynamic_shape:
            raise ValueError(
                "Dynamic shape is not currently supported with torchao ops. Please use --disable_dynamic_shape. "
                "If you need this feature, please file an issue."
            )


def _export_llama(modelname, args) -> LLMEdgeManager:  # noqa: C901
    _validate_args(args)
3 changes: 1 addition & 2 deletions examples/models/llama/install_requirements.sh
@@ -10,8 +10,7 @@
pip install snakeviz sentencepiece

# Install torchao.
-TORCHAO_VERSION=$(cat "$(dirname "$0")"/../../../.ci/docker/ci_commit_pins/torchao.txt)
-pip install --no-use-pep517 "git+https://github.com/pytorch/ao.git@${TORCHAO_VERSION}"
+pip install "$(dirname "$0")/../../../third-party/ao"

# Install lm-eval for Model Evaluation with lm-evaluation-harness
# Install tiktoken for tokenizer
61 changes: 61 additions & 0 deletions examples/models/llama/source_transformation/quantize.py
@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import re
from functools import partial
from pathlib import Path
from typing import Any, Dict, Optional
@@ -70,6 +72,26 @@ def quantize( # noqa C901
if qmode == "int8":
# Add quantization mode options here: group size, bit width, etc.
return WeightOnlyInt8QuantHandler(model).quantized_model()
elif qmode.startswith("torchao:"):
pattern = r"torchao:8da(\d+)w"
matches = re.findall(pattern, qmode)
assert len(matches) == 1, f"Expected 1 match for pattern but got {len(matches)}"
bitwidth = int(matches[0][0])
_load_torchao_ops_aten()
from torchao.experimental.quant_api import Int8DynActIntxWeightLinearQuantizer

with torch.no_grad():
model = Int8DynActIntxWeightLinearQuantizer(
device="cpu",
precision=torch.float32,
groupsize=group_size,
bitwidth=bitwidth,
has_weight_zeros=False,
).quantize(model)

if verbose:
print("quantized model:", model)
return model
elif qmode == "8da4w":
# Check for required args
if group_size is None:
@@ -79,6 +101,7 @@
        model = Int8DynActInt4WeightQuantizer(
            precision=torch_dtype, groupsize=group_size
        ).quantize(model)

        if verbose:
            print("quantized model:", model)
        return model
@@ -692,6 +715,25 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor:


def get_quant_embedding_transform(args):
    if args.embedding_quantize.startswith("torchao:"):
        bitwidth, group_size = args.embedding_quantize.split(":")[1].split(",")
        group_size = int(group_size)
        bitwidth = int(bitwidth)
        _load_torchao_ops_aten()
        from torchao.experimental.quant_api import IntxWeightEmbeddingQuantizer

        def _torchao_embedding_quantizer(model):
            with torch.no_grad():
                model = IntxWeightEmbeddingQuantizer(
                    device="cpu",
                    precision=torch.float32,
                    bitwidth=bitwidth,
                    groupsize=group_size,
                ).quantize(model)
            return model

        return _torchao_embedding_quantizer

    bitwidth, group_size = args.embedding_quantize.split(",")
    if group_size == "none" or group_size == "None" or group_size == "0":
        group_size = None
@@ -733,4 +775,23 @@ def get_quant_weight_transform(args, dtype_override, verbose):
    )


def _load_torchao_ops_aten():
    import glob
    import os

    libs = glob.glob(
        os.path.abspath(
            os.path.join(
                os.environ.get("CMAKE_INSTALL_PREFIX", ""),
                "lib/libtorchao_ops_aten.*",
            )
        )
    )
    assert (
        len(libs) == 1
    ), f"Expected 1 library but got {len(libs)}. If you installed the torchao ops in a non-standard location, please set CMAKE_INSTALL_PREFIX correctly."
    logging.info(f"Loading custom ops library: {libs[0]}")
    torch.ops.load_library(libs[0])


############################ Source Transform End #######################
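
Taken together, an eager-mode run of the new torchao paths needs the ops library discoverable under $CMAKE_INSTALL_PREFIX/lib for _load_torchao_ops_aten, and uses the torchao:<bitwidth>,<groupsize> form for --embedding_quantize. A hedged end-to-end sketch; the install prefix and export command are assumptions, not prescribed by this PR:

```bash
# Point _load_torchao_ops_aten at the directory where libtorchao_ops_aten was
# installed (hypothetical prefix).
export CMAKE_INSTALL_PREFIX=$(pwd)/cmake-out
# 4-bit embeddings with group size 32, plus 3-bit torchao linear quantization.
python -m examples.models.llama.export_llama \
    --embedding_quantize torchao:4,32 \
    -qmode torchao:8da3w \
    --disable_dynamic_shape ...
```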
3 changes: 1 addition & 2 deletions examples/models/llama3_2_vision/install_requirements.sh
@@ -9,5 +9,4 @@
pip install --pre torchtune --extra-index-url https://download.pytorch.org/whl/nightly/cpu --no-cache-dir

# Install torchao.
-TORCHAO_VERSION=$(cat "$(dirname "$0")"/../../../.ci/docker/ci_commit_pins/torchao.txt)
-pip install --no-use-pep517 "git+https://github.com/pytorch/ao.git@${TORCHAO_VERSION}"
+pip install "$(dirname "$0")/../../../third-party/ao"
3 changes: 1 addition & 2 deletions examples/models/phi-3-mini-lora/install_requirements.sh
@@ -10,5 +10,4 @@ pip install torchtune
pip install tiktoken

# Install torchao.
-TORCHAO_VERSION=$(cat "$(dirname "$0")"/../../../.ci/docker/ci_commit_pins/torchao.txt)
-pip install --no-use-pep517 "git+https://github.com/pytorch/ao.git@${TORCHAO_VERSION}"
+pip install "$(dirname "$0")/../../../third-party/ao"
1 change: 1 addition & 0 deletions third-party/ao
Submodule ao added at 75d069