Commit 813e0b8

vanbasten23 authored and rahul-tuli committed
Use xla flag to improve the quantized model performance (#19303)
Signed-off-by: Xiongfei Wei <isaacwxf23@gmail.com>
Signed-off-by: Rahul Tuli <rahul@neuralmagic.com>
1 parent 4589b94 · commit 813e0b8

File tree: 1 file changed, +4 −1 lines changed

vllm/v1/worker/tpu_worker.py

Lines changed: 4 additions & 1 deletion
@@ -101,7 +101,10 @@ def init_device(self):
         # fix this. It will be removed after the bug in XLA compiler is fixed.
         os.environ["LIBTPU_INIT_ARGS"] = (
             os.environ.get("LIBTPU_INIT_ARGS", "") +
-            " --xla_tpu_force_1d_allreduce_at_chunk_count=1")
+            " --xla_tpu_force_1d_allreduce_at_chunk_count=1"
+            " --xla_jf_conv_input_fusion=False")
+        # --xla_jf_conv_input_fusion=False is used to improve the perf of
+        # quantized matmul.
         torch.set_grad_enabled(False)
         torch.set_default_dtype(self.model_config.dtype)
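
For context, a minimal standalone sketch (an assumption, not part of the commit): applying the same LIBTPU_INIT_ARGS tweak manually before starting a TPU process, mirroring what init_device() does above. It preserves any flags already set in the environment; --xla_jf_conv_input_fusion=False is the flag this commit adds to speed up quantized matmul.

import os

# Append the XLA flags used by vLLM's TPU worker, keeping any flags that are
# already present in LIBTPU_INIT_ARGS. The conv-input-fusion flag improves the
# performance of quantized matmul; the allreduce flag works around an XLA
# compiler bug (see the comment in the diff above).
os.environ["LIBTPU_INIT_ARGS"] = (
    os.environ.get("LIBTPU_INIT_ARGS", "") +
    " --xla_tpu_force_1d_allreduce_at_chunk_count=1"
    " --xla_jf_conv_input_fusion=False")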
