Skip to content

Commit 7740c6d

Browse files
xuzhao9facebook-github-bot
authored andcommitted
Cleanup gbps metrics
Summary: gbps is a user-custom metric, not built-in metric, we are unifying its API here. Reviewed By: FindHao, adamomainz Differential Revision: D78163498 fbshipit-source-id: 3bdca98094b2ad1b2e1041cfb593c538562a8dc8
1 parent 4821d26 commit 7740c6d

File tree

12 files changed

+12
-24
lines changed

12 files changed

+12
-24
lines changed

tritonbench/operators/gather_gemv/operator.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,9 @@
55
"""
66

77
import argparse
8-
import csv
9-
import os
10-
import statistics
11-
from typing import Any, Callable, Generator, List, Optional
8+
from typing import Callable, Generator, List, Optional
129

13-
import numpy
1410
import torch
15-
import triton
1611
from torch._dynamo.testing import rand_strided
1712

1813
from tritonbench.utils.triton_op import (
@@ -27,7 +22,7 @@
2722

2823
class Operator(BenchmarkOperator):
2924
@register_metric()
30-
def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
25+
def gbps(self, fn, example_inputs, metrics: BenchmarkOperatorMetrics):
3126
arg0_1, arg1_1, arg2_1 = example_inputs
3227
return (
3328
2

tritonbench/operators/jagged_layer_norm/operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
180180
)
181181

182182
@register_metric(skip_baseline=True)
183-
def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
183+
def gbps(self, fn, example_inputs, metrics: BenchmarkOperatorMetrics):
184184
return (
185185
example_inputs[0].element_size()
186186
* example_inputs[0].numel()

tritonbench/operators/jagged_mean/operator.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
import argparse
2-
import itertools
3-
import math
42
import os
5-
import random
63
from typing import Callable, Generator, List, Optional, Tuple
74

85
import torch
@@ -229,7 +226,7 @@ def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
229226
)
230227

231228
@register_metric()
232-
def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
229+
def gbps(self, fn, example_inputs, metrics: BenchmarkOperatorMetrics):
233230
return (
234231
example_inputs[0].element_size()
235232
* example_inputs[0].numel()

tritonbench/operators/jagged_softmax/operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
213213
)
214214

215215
@register_metric(skip_baseline=True)
216-
def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
216+
def gbps(self, fn, example_inputs, metrics: BenchmarkOperatorMetrics):
217217
return (
218218
example_inputs[0].element_size()
219219
* example_inputs[0].numel()

tritonbench/operators/jagged_sum/operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
227227
)
228228

229229
@register_metric()
230-
def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
230+
def gbps(self, fn, example_inputs, metrics: BenchmarkOperatorMetrics):
231231
return (
232232
example_inputs[0].element_size()
233233
* example_inputs[0].numel()

tritonbench/operators/layer_norm/operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def get_x_val(self, args):
8888
return N
8989

9090
@register_metric()
91-
def gbps(self, fn_name, args, metrics: BenchmarkOperatorMetrics) -> float:
91+
def gbps(self, fn, args, metrics: BenchmarkOperatorMetrics) -> float:
9292
x = args[0]
9393
base = x.numel() * x.element_size() / metrics.latency * 1e-6
9494
return {

tritonbench/operators/low_mem_dropout/operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
class Operator(BenchmarkOperator):
1717
@register_metric()
18-
def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
18+
def gbps(self, fn, example_inputs, metrics: BenchmarkOperatorMetrics):
1919
return (
2020
3
2121
* example_inputs[1].element_size()

tritonbench/operators/softmax/operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def get_x_val(self, example_inputs):
122122
return [shape[0], shape[1]]
123123

124124
@register_metric()
125-
def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics) -> float:
125+
def gbps(self, fn, example_inputs, metrics: BenchmarkOperatorMetrics) -> float:
126126
return (
127127
2
128128
* example_inputs[0].nelement()

tritonbench/operators/sum/operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ def _get_accuracy(self, fn: Callable, baseline_fn: Callable) -> bool:
295295
)
296296

297297
@register_metric()
298-
def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
298+
def gbps(self, fn, example_inputs, metrics: BenchmarkOperatorMetrics):
299299
return (
300300
example_inputs[0].element_size()
301301
* example_inputs[0].numel()

tritonbench/operators/vector_add/operator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class Operator(BenchmarkOperator):
1717
DEFAULT_METRICS = ["latency", "gbps"]
1818

1919
@register_metric()
20-
def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
20+
def gbps(self, fn, example_inputs, metrics: BenchmarkOperatorMetrics):
2121
def normalize(lat):
2222
return (
2323
3

0 commit comments

Comments
 (0)