Skip to content

Commit b519f89

Browse files
xuzhao9facebook-github-bot
authored andcommitted
format code and fix bugs from input loader (#272)
Summary: The Aten op diff broke OSS CI, fixing it. Remove the old `operator_loader` files. Pull Request resolved: #272 Reviewed By: FindHao Differential Revision: D78094345 Pulled By: xuzhao9 fbshipit-source-id: 23c8a6fd311bfc654de9502c067dd3fdccac8152
1 parent 90977e9 commit b519f89

File tree

12 files changed

+39
-753
lines changed

12 files changed

+39
-753
lines changed

docker/tritonbench-nightly.dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ RUN echo "\
5151
conda activate base\n\
5252
export CONDA_HOME=/workspace/miniconda3\n\
5353
export CUDA_HOME=/usr/local/cuda\n\
54-
export PATH=/home/runner/bin\${PATH:+:\${PATH}}\n\
54+
export PATH=\${CUDA_HOME}/bin:/home/runner/bin\${PATH:+:\${PATH}}\n\
5555
export LD_LIBRARY_PATH=\${CUDA_HOME}/lib64\${LD_LIBRARY_PATH:+:\${LD_LIBRARY_PATH}}\n\
5656
export LIBRARY_PATH=\${CUDA_HOME}/lib64\${LIBRARY_PATHPATH:+:\${LIBRARY_PATHPATH}}\n" >> /workspace/setup_instance.sh
5757

test/test_gpu/main.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,11 @@
4444
# Ops that require special arguments in backwards
4545
BWD_ARGS_OPS: Dict[str, List[str]] = skip_tests.get("bwd_args", {})
4646

47-
TEST_OPERATORS = set(list_operators_by_collection(op_collection="buck"))
47+
TEST_OPERATORS = (
48+
set(list_operators_by_collection(op_collection="buck"))
49+
if is_fbcode()
50+
else set(list_operators_by_collection(op_collection="default"))
51+
)
4852

4953

5054
def check_ci_output(op):

test/test_gpu/skip_tests_h100_pytorch.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ fp8_fused_quant_gemm_rowwise:
2626
gemm:
2727
# internal only kernels
2828
- hstu_triton_matmul
29-
# TODO: fix PT2 cutlass
29+
# pt2 cutlass kernel
3030
- pt2_cutlass_matmul
3131
# jagged tests are slow, so disable them in OSS
3232
jagged_layer_norm:

tritonbench/__init__.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +0,0 @@
1-
from .operators import list_operators, load_opbench_by_name
2-
from .operators_collection import (
3-
list_operator_collections,
4-
list_operators_by_collection,
5-
)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .utils import is_loader_op, get_op_loader_bench_cls_by_name, list_loader_operators
1+
from .utils import get_op_loader_bench_cls_by_name, is_loader_op, list_loader_operators
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .loader import list_aten_ops, get_aten_loader_cls_by_name
1+
from .loader import get_aten_loader_cls_by_name, list_aten_ops

tritonbench/operator_loader/aten/loader.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
11
import argparse
22
import os
3+
import types
4+
from typing import Dict, Generator, List, Optional
5+
36
import torch
47
import yaml
5-
import types
6-
from typing import List, Optional, Generator, Dict
78
from tritonbench.utils.triton_op import BenchmarkOperator, register_benchmark
89

910
# The config file defines available ATen operators and their corresponding input shapes.
1011
ATEN_CONFIG_YAML = os.path.join(os.path.dirname(__file__), "config.yaml")
1112
aten = torch.ops.aten
1213

13-
class AtenOperator(BenchmarkOperator):
1414

15-
def __init__(self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None):
15+
class AtenOperator(BenchmarkOperator):
16+
def __init__(
17+
self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
18+
):
1619
super().__init__(tb_args, extra_args)
1720
self.aten_op = eval(self.aten_op_name)
1821
if not self.tb_args.input_loader:
@@ -42,7 +45,8 @@ def get_default_aten_op_input(aten_op_name: str) -> str:
4245
config = yaml.safe_load(f)
4346
return config[aten_op_name]
4447

45-
def get_aten_loader_cls_by_name(aten_op_name: str, aten_op_input: Optional[str]=None):
48+
49+
def get_aten_loader_cls_by_name(aten_op_name: str, aten_op_input: Optional[str] = None):
4650
"""
4751
Return a class generated from the given aten op name and input.
4852
If input is not provided, use the default input from the config file.
@@ -55,12 +59,19 @@ def get_aten_loader_cls_by_name(aten_op_name: str, aten_op_input: Optional[str]=
5559
op_name_module.Operator = op_class
5660
op_class.name = op_cls_name
5761
op_class.aten_op_name = aten_op_name
58-
op_class.aten_op_input = aten_op_input if aten_op_input else get_default_aten_op_input(aten_op_name)
62+
op_class.aten_op_input = (
63+
aten_op_input if aten_op_input else get_default_aten_op_input(aten_op_name)
64+
)
5965
# register two backends for each aten op: eager and inductor
60-
register_benchmark(operator_name=op_cls_name, func_name="eager", baseline=True)(op_class.eager)
61-
register_benchmark(operator_name=op_cls_name, func_name="inductor", baseline=False)(op_class.inductor)
66+
register_benchmark(operator_name=op_cls_name, func_name="eager", baseline=True)(
67+
op_class.eager
68+
)
69+
register_benchmark(operator_name=op_cls_name, func_name="inductor", baseline=False)(
70+
op_class.inductor
71+
)
6272
return op_class
6373

74+
6475
def list_aten_ops() -> Dict[str, str]:
6576
"""
6677
Load all ATen operators from the config file.

0 commit comments

Comments
 (0)