Describe the issue
BatchNorm1d exported to ONNX is extremely slow on CPU compared both to BatchNorm2d and to a naive implementation. A quick workaround is to use BatchNorm2d instead, with reshaping that keeps the computation mathematically equivalent.
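For reference, the naive implementation used as a baseline can be as simple as normalizing with batch statistics directly (a minimal sketch, assuming the defaults of nn.BatchNorm1d with track_running_stats=False, i.e. eps=1e-5 and per-feature affine parameters):

import torch

def naive_batchnorm1d(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Normalize each feature over the batch dimension, as BatchNorm1d does
    # on a (batch, features) input when track_running_stats=False.
    mean = x.mean(dim=0)
    var = x.var(dim=0, unbiased=False)
    return (x - mean) / torch.sqrt(var + eps) * weight + bias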
Performance observation
BatchNorm1d:
Latency (ms): Min=7.931, Max=12.540
Mean=8.081 ± 0.572 std
Percentiles: P50=7.989, P90=8.030
Throughput: 123.8 req/s
BatchNorm2d with transpose and reshaping:
Latency (ms): Min=0.384, Max=2.275
Mean=0.404 ± 0.062 std
Percentiles: P50=0.389, P90=0.401
Throughput: 2477.8 req/s
To reproduce
import torch
from torch import nn
import inspect

class Dummy1(nn.Module):
    def __init__(self):
        super().__init__()
        self.forward_args = list(inspect.signature(self.forward).parameters.keys())
        self.bn = nn.BatchNorm1d(64, track_running_stats=False)

    def forward(
        self,
        features: torch.Tensor,  # (batch_size, emb_dim)
    ) -> torch.Tensor:
        features = self.bn(features)
        return features

class Dummy2(nn.Module):
    def __init__(self):
        super().__init__()
        self.forward_args = list(inspect.signature(self.forward).parameters.keys())
        self.bn = nn.BatchNorm2d(64, track_running_stats=False)

    def forward(
        self,
        features: torch.Tensor,  # (batch_size, emb_dim)
    ) -> torch.Tensor:
        # Reshape (batch, emb) -> (1, emb, batch, 1) so that BatchNorm2d
        # normalizes over the same elements as BatchNorm1d, then undo it.
        features = self.bn(features.T.unsqueeze(0).unsqueeze(-1)).squeeze().T
        # candidates_embeddings = self.bn(candidates_embeddings.reshape(1, 64, 2000, 1)).squeeze().T
        return features

dum = Dummy1()
dum2 = Dummy2()

# Check that the two modules produce identical outputs.
for i in range(100):
    feats = torch.randn(2000, 64)
    out1 = dum(feats)
    out2 = dum2(feats)
    assert torch.allclose(out1, out2)
print("All checks are passed!")
import time
import numpy as np
import onnxruntime as ort
import psutil
import gc
import os

def make_model_and_convert(model_class, model_name):
    fusion = model_class()
    data = {
        "features": torch.randn(2000, 64),
    }
    # Run a few forward passes before export.
    for i in range(10):
        test_data = {k: v for k, v in data.items() if k in fusion.forward_args}
        fusion(**test_data)
    fusion.eval()
    os.makedirs("temp", exist_ok=True)  # make sure the output directory exists
    torch.onnx.export(
        model=fusion,
        args=data,  # model input (or a tuple for multiple inputs)
        f=f"temp/{model_name}.onnx",  # where to save the model (can be a file or file-like object)
        verbose=True,
        export_params=True,  # store the trained parameter weights inside the model file
        opset_version=15,  # the ONNX version to export the model to
        do_constant_folding=True,  # whether to execute constant folding for optimization
        input_names=fusion.forward_args,  # the model's input names
        output_names=None,  # the model's output names
    )

def benchmark_onnx(model_path, data, num_runs=500, warmup_runs=10, threads=1):
    os.environ["OMP_NUM_THREADS"] = str(threads)
    os.environ["KMP_BLOCKTIME"] = "1"
    providers = ["CPUExecutionProvider"]
    sess_options = ort.SessionOptions()
    sess_options.intra_op_num_threads = threads
    sess_options.inter_op_num_threads = 1
    sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    session = ort.InferenceSession(model_path, sess_options, providers=providers)
    input_names = [x.name for x in session.get_inputs()]
    input_data = {k: ort.OrtValue.ortvalue_from_numpy(np.array(v)) for k, v in data.items() if k in input_names}
    gc.disable()
    for _ in range(warmup_runs):
        session.run(None, input_data)
    latencies = []
    process = psutil.Process(os.getpid())
    process.cpu_affinity([0])  # pin to a single core for stable timings
    for _ in range(num_runs):
        start_time = time.perf_counter_ns()
        session.run(None, input_data)
        end_time = time.perf_counter_ns()
        latencies.append((end_time - start_time) / 1e6)  # ms
    gc.enable()
    latencies = np.array(latencies)
    print(f"\n{'='*40} RESULTS {'='*40}")
    print(f"Warmup: {warmup_runs} | Iterations: {num_runs}")
    print(f"Latency (ms): Min={np.min(latencies):.3f}, Max={np.max(latencies):.3f}")
    print(f"Mean={np.mean(latencies):.3f} ± {np.std(latencies):.3f} std")
    print(f"Percentiles: P50={np.percentile(latencies, 50):.3f}, P90={np.percentile(latencies, 90):.3f}")
    print(f"Throughput: {1000/np.mean(latencies):.1f} req/s")
    print("=" * 90)
    return

data = {
    "features": torch.randn(2000, 64),
}
make_model_and_convert(Dummy1, "dummy")
make_model_and_convert(Dummy2, "dummy2")

model_path = "temp/dummy.onnx"
benchmark_onnx(
    model_path,
    data,
    num_runs=2000,
    warmup_runs=1000,
)

model_path = "temp/dummy2.onnx"
benchmark_onnx(
    model_path,
    data,
    num_runs=2000,
    warmup_runs=1000,
)
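To see where the slowdown comes from, it can help to compare the operators each exported graph actually contains (a small inspection sketch using the onnx package; the temp/*.onnx paths are the ones produced by the repro above):

import onnx

for path in ("temp/dummy.onnx", "temp/dummy2.onnx"):
    model = onnx.load(path)
    print(path, "->", [node.op_type for node in model.graph.node])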
Urgency
Not urgent
Platform
Linux
OS Version
Ubuntu 22.04.3 LTS
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.20.1
ONNX Runtime API
Python
Architecture
X64
Execution Provider
Default CPU
Execution Provider Library Version
No response
Model File
No response
Is this a quantized model?
No