
Commit 92c4bda

Moving optimizations and int_mm_fused kernel change to main (#60)
1 parent: 786153b

File tree: 4 files changed, +24 −257 lines

experiments/eval_combo.py

Lines changed: 6 additions & 2 deletions

@@ -7,7 +7,11 @@
 import segment_anything_fast
 
 torch._dynamo.config.cache_size_limit = 50000
-
+# torch._inductor.config.fx_graph_cache = True # seems to slow performance
+torch._inductor.config.epilogue_fusion = False
+torch._inductor.config.coordinate_descent_tuning = True
+torch._inductor.config.coordinate_descent_check_all_directions = True
+torch._inductor.config.force_fuse_int_mm_with_mul = True
 
 def unbind_jagged(device, data, sizes, offsets):
     if data is None:

@@ -193,7 +197,7 @@ def build_results(batched_data_iter,
         if batch_idx == 0:
             with torch.autograd.profiler.record_function("compilation and warmup"):
                 if str(use_compile) != "False":
-                    predictor.model.image_encoder = torch.compile(predictor.model.image_encoder, mode=use_compile)
+                    predictor.model.image_encoder = torch.compile(predictor.model.image_encoder, mode=use_compile, fullgraph=True,)
                 # Run first batch a few times for warmup and exclude it from the final timings
                 for _ in range(3):
                     _ = batch_runner(predictor, batch, batch_size, pad_input_image_batch)
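
For context, the flags added above are plain torch._inductor.config settings, so their effect can be reproduced outside the benchmark script. A minimal sketch under assumptions: ToyEncoder is a hypothetical stand-in for predictor.model.image_encoder, mode="max-autotune" is only an illustrative value for use_compile, and a CUDA device is assumed.

import torch

# Same inductor settings the benchmark now applies before compiling.
torch._dynamo.config.cache_size_limit = 50000
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True

class ToyEncoder(torch.nn.Module):
    # Hypothetical stand-in for the SAM image encoder.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(256, 256)

    def forward(self, x):
        return torch.nn.functional.relu(self.proj(x))

encoder = ToyEncoder().cuda().half()
# fullgraph=True makes torch.compile raise on graph breaks instead of
# silently splitting the model into several compiled regions.
encoder = torch.compile(encoder, mode="max-autotune", fullgraph=True)
out = encoder(torch.randn(8, 64, 256, device="cuda", dtype=torch.half))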
Binary file not shown.

segment_anything_fast/dynamic_quant.py

Lines changed: 18 additions & 23 deletions

@@ -5,7 +5,7 @@
 
 import torch
 from torch._dynamo import is_compiling as dynamo_is_compiling
-from segment_anything_fast.int_mm import _int_mm_dequant
+from torch._higher_order_ops.out_dtype import out_dtype
 
 def quantize_activation_per_token(t, scales):
     t = torch.round(t / scales).clamp(-127, 127).to(torch.int8)

@@ -14,10 +14,13 @@ def quantize_activation_per_token(t, scales):
 def quantize_activation_per_token_absmax(t):
     n_bits = 8
     # if the shape of t is [B, N, K], the shape of scales will be [B, N, 1]
-    # want float scales to avoid overflows
-    scales = t.abs().max(dim=-1, keepdim=True)[0].float()
+
+    # dequant with fp16 scale can cause overflow
+    # otherwise leave scales as same dtype
+    if t.dtype == torch.float16:
+        t = t.float()
     q_max = 2 ** (n_bits - 1) - 1
-    scales.clamp_(min=1e-5).div_(q_max)
+    scales = t.abs().max(dim=-1, keepdim=True)[0].clamp(min=1e-5).div(q_max)
     # Note: the original smoothquant does not clamp to qmin/qmax here,
     # but some of the tests with bfloat16 ended up with a flipped sign
     # if we don't clamp. TODO(future) look into this further.

@@ -58,18 +61,18 @@ def quant_int8_dynamic_per_token_linear(
     w_vals_int8_t,
     w_scales,
     bias,
-    out_dtype=torch.float32,
+    output_dtype=torch.float32,
 ):
 
     def quant_int8_per_token_matmul(
         x_vals_int8,
         x_scales,
         w_vals_int8_t,
         w_scales,
-        out_dtype=torch.float32,
+        output_dtype=torch.float32,
     ):
         # Quantized matmul of int8 operands that accumulates to int32 and returns
-        # out_dtype. For now, this is written for approximate numerical
+        # output_dtype. For now, this is written for approximate numerical
         # Assumes that activation and weight quantization are symmetric,
         # i.e. act_zp and w_zp is 0.
         # Assumes that weight quantization is per-channel.

@@ -78,7 +81,7 @@ def quant_int8_per_token_matmul(
         # https://github.com/google/gemmlowp/blob/master/doc/quantization.md
         # for an overview of quantized matmul compute
 
-        # in scalar form, assuming out_dtype is fp32 and zw == 0:
+        # in scalar form, assuming output_dtype is fp32 and zw == 0:
         #
         # Y_i_j_fp32 = sx * sw dot(X_i, W_j)
         #

@@ -87,19 +90,16 @@ def quant_int8_per_token_matmul(
             f'x dtype {x_vals_int8.dtype} not yet supported'
         assert w_vals_int8_t.dtype == torch.int8, \
             f'w dtype {w_vals_int8_t.dtype} not yet supported'
-        assert w_scales.dtype == out_dtype, \
-            f'{w_scales.dtype} does not match {out_dtype}'
+        assert w_scales.dtype == output_dtype, \
+            f'{w_scales.dtype} does not match {output_dtype}'
 
         #
         # 1. do the matrix form of dot(X_i, W_j)
         #
 
 
-        # TODO(before land): add test case for input with bsz
         tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
-        x_scales_flat = x_scales.view(tmp.size(0), 1)
-        w_scales_flat = w_scales.unsqueeze(0)
-        # y_dot_int32 = torch._int_mm(tmp, w_vals_int8_t)
+        y_dot_int32 = out_dtype(torch.ops.aten.mm.default, torch.int32, tmp, w_vals_int8_t)
 
         #
         # 2. rescale the output

@@ -108,21 +108,16 @@ def quant_int8_per_token_matmul(
         # large that y_dot_int32 * a float16 scale is greater than the maximum
         # value of a float 16, (which results in a value of inf even if multiplying
         # by the other scale would bring it within the expected range)
+        assert x_scales.dtype != torch.float16, f"can have overflows happen when x_scales is float16"
 
-        assert x_scales.dtype == torch.float, f"x_scales needs to be a torch.float32 but got {x_scales.dtype}"
-
-        # y = y_dot_int32 * x_scales_flat * w_scales_flat
-        # # can downcast only at the very end
-        # y = y.to(out_dtype)
+        y = (y_dot_int32 * x_scales.view(-1, 1) * w_scales)
 
-        # y = _int_mm_dequant(tmp, w_vals_int8_t, x_scales_flat, w_scales_flat, out_dtype)
-        y = torch.ops.custom_int_mm.int_mm_dequant(tmp, w_vals_int8_t, x_scales_flat, w_scales_flat, out_dtype)
-        return y.reshape(*x_vals_int8.shape[:-1], -1)
+        return y.reshape(*x_vals_int8.shape[:-1], -1).to(output_dtype)
 
     # like F.linear, but with int8 dynamic quantization of activation,
     # and a quantized weight
     mm_out = quant_int8_per_token_matmul(
-        x_vals_int8, x_scales, w_vals_int8_t, w_scales, out_dtype)
+        x_vals_int8, x_scales, w_vals_int8_t, w_scales, output_dtype)
     if bias is not None:
         mm_out += bias
     return mm_out
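
Putting the dynamic_quant.py changes together: activations are quantized per token (fp16 inputs are promoted to fp32 so the scales cannot overflow), the int8 matmul accumulates to int32 through the out_dtype higher-order op, and the rescale is a plain multiply that inductor can fuse with the matmul once force_fuse_int_mm_with_mul is set, which is what made the custom fused kernel unnecessary. Below is a self-contained sketch of that path, not the repo's API: toy_int8_dynamic_linear, the weight setup, and the shapes are all hypothetical.

import torch
from torch._higher_order_ops.out_dtype import out_dtype

def toy_int8_dynamic_linear(x, w_int8_t, w_scales):
    # Per-token absmax quantization; fp16 inputs are promoted so the
    # scales cannot overflow during dequantization.
    if x.dtype == torch.float16:
        x = x.float()
    x_scales = x.abs().max(dim=-1, keepdim=True)[0].clamp(min=1e-5).div(127)
    x_int8 = torch.round(x / x_scales).clamp(-127, 127).to(torch.int8)

    # int8 x int8 matmul accumulating to int32; under torch.compile the
    # trailing scale multiplies can be fused into the matmul kernel.
    tmp = x_int8.reshape(-1, x_int8.shape[-1])
    y_dot_int32 = out_dtype(torch.ops.aten.mm.default, torch.int32, tmp, w_int8_t)

    # Rescale by activation and weight scales, downcast only at the end.
    y = y_dot_int32 * x_scales.view(-1, 1) * w_scales
    return y.reshape(*x_int8.shape[:-1], -1).to(torch.float32)

# Example usage with a per-channel int8 weight (shapes are illustrative).
w = torch.randn(256, 512)                      # already transposed: [in, out]
w_scales = w.abs().amax(dim=0).clamp(min=1e-5) / 127
w_int8_t = torch.round(w / w_scales).clamp(-127, 127).to(torch.int8)
y = toy_int8_dynamic_linear(torch.randn(2, 64, 256), w_int8_t, w_scales)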

segment_anything_fast/int_mm.py

Lines changed: 0 additions & 232 deletions
This file was deleted.
