Skip to content

Commit 6a69763

Browse files
authored
build: RHEL8 EA2 Backends (#7568)
* build: RHEL8 EA2 Backends
1 parent 187a4a3 commit 6a69763

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

build.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,8 @@ def fail_if(p, msg):
116116

117117

118118
def target_platform():
119-
if FLAGS.target_platform is not None:
119+
# When called by compose.py, FLAGS will be None
120+
if FLAGS and FLAGS.target_platform is not None:
120121
return FLAGS.target_platform
121122
platform_string = platform.system().lower()
122123
if platform_string == "linux":
@@ -132,7 +133,8 @@ def target_platform():
132133

133134

134135
def target_machine():
135-
if FLAGS.target_machine is not None:
136+
# When called by compose.py, FLAGS will be None
137+
if FLAGS and FLAGS.target_machine is not None:
136138
return FLAGS.target_machine
137139
return platform.machine().lower()
138140

@@ -639,13 +641,16 @@ def pytorch_cmake_args(images):
639641
cmake_backend_arg("pytorch", "TRITON_PYTORCH_DOCKER_IMAGE", None, image),
640642
]
641643

642-
if FLAGS.enable_gpu:
644+
# TODO: TPRD-372 TorchTRT extension is not currently supported by our manylinux build
645+
# TODO: TPRD-373 NVTX extension is not currently supported by our manylinux build
646+
if target_platform() != "rhel":
647+
if FLAGS.enable_gpu:
648+
cargs.append(
649+
cmake_backend_enable("pytorch", "TRITON_PYTORCH_ENABLE_TORCHTRT", True)
650+
)
643651
cargs.append(
644-
cmake_backend_enable("pytorch", "TRITON_PYTORCH_ENABLE_TORCHTRT", True)
652+
cmake_backend_enable("pytorch", "TRITON_ENABLE_NVTX", FLAGS.enable_nvtx)
645653
)
646-
cargs.append(
647-
cmake_backend_enable("pytorch", "TRITON_ENABLE_NVTX", FLAGS.enable_nvtx)
648-
)
649654
return cargs
650655

651656

@@ -1301,7 +1306,6 @@ def dockerfile_prepare_container_linux(argmap, backends, enable_gpu, target_mach
13011306
gpu_enabled=gpu_enabled
13021307
)
13031308

1304-
# This
13051309
if target_platform() == "rhel":
13061310
df += """
13071311
# Common dpeendencies.

0 commit comments

Comments
 (0)