
Commit 6d2a3c2

change line-length to 100
1 parent b518889 commit 6d2a3c2

11 files changed (+35, -105 lines)

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ test = [
 include = ["tritonparse*"]

 [tool.black]
-line-length = 88
+line-length = 100
 target-version = ["py310"]

 [tool.ufmt]
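The rest of the diff below is the formatter re-wrapping calls that now fit within 100 columns. A minimal sketch of that effect using black's Python API (black.format_str with black.Mode; illustrative only, not code from this repository):

import black

SRC = '''\
def run():
    tritonparse.utils.unified_parse(source=log_path, out="./parsed_output", overwrite=True)
'''

# Under the old 88-column limit black wraps the long call across several lines...
print(black.format_str(SRC, mode=black.Mode(line_length=88)))
# ...while under the new 100-column limit it stays on a single line.
print(black.format_str(SRC, mode=black.Mode(line_length=100)))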

tests/test_add.py

Lines changed: 1 addition & 3 deletions
@@ -72,6 +72,4 @@ def test_tensor_add():
 if __name__ == "__main__":
     test_tensor_add()
     # Use improved unified_parse with explicit output directory
-    tritonparse.utils.unified_parse(
-        source=log_path, out="./parsed_output", overwrite=True
-    )
+    tritonparse.utils.unified_parse(source=log_path, out="./parsed_output", overwrite=True)

tests/test_tritonparse.py

Lines changed: 2 additions & 6 deletions
@@ -211,18 +211,14 @@ def run_test_kernel(x):
     torch.cuda.synchronize()

     # Check that temp_dir_logs folder has content
-    assert os.path.exists(
-        temp_dir_logs
-    ), f"Log directory {temp_dir_logs} does not exist."
+    assert os.path.exists(temp_dir_logs), f"Log directory {temp_dir_logs} does not exist."
     log_files = os.listdir(temp_dir_logs)
     assert (
         len(log_files) > 0
     ), f"No log files found in {temp_dir_logs}. Expected log files to be generated during Triton compilation."
     print(f"Found {len(log_files)} log files in {temp_dir_logs}: {log_files}")

-    tritonparse.utils.unified_parse(
-        source=temp_dir_logs, out=temp_dir_parsed, overwrite=True
-    )
+    tritonparse.utils.unified_parse(source=temp_dir_logs, out=temp_dir_parsed, overwrite=True)

     # Clean up temporary directory
     try:

tritonparse/common.py

Lines changed: 3 additions & 9 deletions
@@ -247,9 +247,7 @@ def copy_local_to_tmpdir(local_path: str, verbose: bool = False) -> str:

     for item in os.listdir(local_path):
         item_path = os.path.join(local_path, item)
-        if os.path.isfile(item_path) and os.path.basename(item_path).startswith(
-            LOG_PREFIX
-        ):
+        if os.path.isfile(item_path) and os.path.basename(item_path).startswith(LOG_PREFIX):
             if verbose:
                 logger.info(f"Copying {item_path} to {temp_dir}")
             shutil.copy2(item_path, temp_dir)
@@ -309,9 +307,7 @@ def parse_logs(
     for rank, files in ranks.items():
         use_filenames = False
         if len(files) > 1:
-            logger.warning(
-                "Warning: multiple logs found for the same rank. Using filenames."
-            )
+            logger.warning("Warning: multiple logs found for the same rank. Using filenames.")
             use_filenames = True
         # Determine rank key for file mapping
         rank_key = "rank_default" if rank.is_default else f"rank_{rank.value}"
@@ -350,9 +346,7 @@ def parse_logs(
         # Add files to the mapping (now with .gz extensions)
         file_mapping[rank_key]["regular_files"].extend(generated_files)
         # this is used to generate the tritonparse url
-        file_mapping[rank_key]["rank_suffix"] = rank_config.to_rank().to_string(
-            suffix="/"
-        )
+        file_mapping[rank_key]["rank_suffix"] = rank_config.to_rank().to_string(suffix="/")
         if mapped_file:
             file_mapping[rank_key]["mapped_file"] = mapped_file


tritonparse/extract_source_mappings.py

Lines changed: 5 additions & 15 deletions
@@ -40,16 +40,12 @@

 # the definition of the PTX loc directive.
 # Example: .loc 1 0 50 // abcdef.py:0:50
-PTX_LOC_PATTERN = re.compile(
-    r"^\s*\.loc\s+\d+\s+(\d+)\s+(\d+)\s+//\s*(.+?):(\d+):(\d+)"
-)
+PTX_LOC_PATTERN = re.compile(r"^\s*\.loc\s+\d+\s+(\d+)\s+(\d+)\s+//\s*(.+?):(\d+):(\d+)")

 # the definition of the AMDGCN loc directive.
 # Example: .loc 1 32 30 ; abcd.py:32:30
 #          .loc 1 32 46 is_stmt 0 ; abcd.py:32:46
-AMDGCN_LOC_PATTERN = re.compile(
-    r".*loc\s+(\d+)\s+(\d+)\s+(\d+)(?:\s+[^;]*)?;\s*(.+?):(\d+):(\d+)"
-)
+AMDGCN_LOC_PATTERN = re.compile(r".*loc\s+(\d+)\s+(\d+)\s+(\d+)(?:\s+[^;]*)?;\s*(.+?):(\d+):(\d+)")


 def get_file_extension(filename: str) -> str:
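For illustration only (not part of this commit), here is what the two patterns above capture for the example lines quoted in the comments:

import re

PTX_LOC_PATTERN = re.compile(r"^\s*\.loc\s+\d+\s+(\d+)\s+(\d+)\s+//\s*(.+?):(\d+):(\d+)")
AMDGCN_LOC_PATTERN = re.compile(r".*loc\s+(\d+)\s+(\d+)\s+(\d+)(?:\s+[^;]*)?;\s*(.+?):(\d+):(\d+)")

# The example lines from the comments above:
print(PTX_LOC_PATTERN.match("    .loc 1 0 50 // abcdef.py:0:50").groups())
# -> ('0', '50', 'abcdef.py', '0', '50')
print(AMDGCN_LOC_PATTERN.match("    .loc 1 32 46 is_stmt 0 ; abcd.py:32:46").groups())
# -> ('1', '32', '46', 'abcd.py', '32', '46')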
@@ -257,9 +253,7 @@ def extract_ptx_amdgcn_mappings(
     def get_file_path(filename: str) -> str:
         file_path = filename
         if not os.path.isabs(filename):
-            logger.debug(
-                f"Filename '{filename}' does not contain a path. Attempting to resolve."
-            )
+            logger.debug(f"Filename '{filename}' does not contain a path. Attempting to resolve.")
             # Attempt to resolve the filename to a full path using referenced_files
             if filename in referenced_files:
                 if len(referenced_files[filename]) > 1:
@@ -488,9 +482,7 @@ def parse_single_trace_content(trace_content: str) -> str:
             create_bidirectional_mapping(
                 ir_maps[src_type], ir_maps[tgt_type], src_type, tgt_type
             )
-            logger.debug(
-                f"Created bidirectional mapping between {src_type} and {tgt_type}"
-            )
+            logger.debug(f"Created bidirectional mapping between {src_type} and {tgt_type}")

     py_map = {}

@@ -602,9 +594,7 @@ def parse_single_file(
         if frame_id is not None or frame_compile_id is not None:
             output_file_name = f"f{frame_id}_fc{frame_compile_id}_a{attempt_id}_cai{compiled_autograd_id}.ndjson"
         else:
-            logger.debug(
-                "No frame_id or frame_compile_id found in the payload."
-            )
+            logger.debug("No frame_id or frame_compile_id found in the payload.")
             output_file_name = f"{file_name_without_extension}_mapped.ndjson"
     else:
         output_file_name = f"{file_name_without_extension}_mapped.ndjson"

tritonparse/shared_vars.py

Lines changed: 1 addition & 3 deletions
@@ -3,7 +3,5 @@

 # The compilation information will be stored to /logs/DEFAULT_TRACE_FILE_PREFIX by default
 # unless other flags disable or set another store. Add USER to avoid permission issues in shared servers.
-DEFAULT_TRACE_FILE_PREFIX = (
-    f"dedicated_log_triton_trace_{os.getenv('USER', 'unknown')}_"
-)
+DEFAULT_TRACE_FILE_PREFIX = f"dedicated_log_triton_trace_{os.getenv('USER', 'unknown')}_"
 DEFAULT_TRACE_FILE_PREFIX_WITHOUT_USER = "dedicated_log_triton_trace_"
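Illustrative only: the prefix this expression produces for a given $USER, with the documented fallback when it is unset:

import os

# USER=alice  -> "dedicated_log_triton_trace_alice_"
# USER unset  -> "dedicated_log_triton_trace_unknown_"
print(f"dedicated_log_triton_trace_{os.getenv('USER', 'unknown')}_")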

tritonparse/source_type.py

Lines changed: 1 addition & 3 deletions
@@ -51,6 +51,4 @@ def _parse_source(self) -> Tuple[SourceType, str]:
         elif path.is_file():
             return SourceType.LOCAL_FILE, str(path.absolute())
         else:
-            raise ValueError(
-                f"Source '{self.source_str}' is not a valid directory or file"
-            )
+            raise ValueError(f"Source '{self.source_str}' is not a valid directory or file")

tritonparse/structured_logging.py

Lines changed: 12 additions & 38 deletions
@@ -173,9 +173,7 @@ def convert(obj):
         return str(obj)

     if is_dataclass(obj):
-        return convert(
-            asdict(obj)
-        )  # Convert dataclass to dict and then process that dict
+        return convert(asdict(obj))  # Convert dataclass to dict and then process that dict
     log.warning(f"Unknown type: {type(obj)}")
     return str(obj)  # Return primitive types as-is

@@ -193,8 +191,7 @@ def maybe_enable_debug_logging():

     # Check if we already have a debug handler
     has_debug_handler = any(
-        isinstance(handler, logging.StreamHandler)
-        and handler.level <= logging.DEBUG
+        isinstance(handler, logging.StreamHandler) and handler.level <= logging.DEBUG
         for handler in log.handlers
     )

@@ -276,11 +273,7 @@ def extract_kernel_name(src) -> Optional[str]:
             return src.getattr("name", None)
         else:
             # For ASTSource, get the function name
-            if (
-                hasattr(src, "fn")
-                and hasattr(src.fn, "fn")
-                and hasattr(src.fn.fn, "__name__")
-            ):
+            if hasattr(src, "fn") and hasattr(src.fn, "fn") and hasattr(src.fn.fn, "__name__"):
                 return src.fn.fn.__name__
         return None
     except Exception as e:
@@ -316,9 +309,7 @@ def should_trace_kernel(
             log.debug(f"Kernel '{kernel_name}' matches pattern '{pattern}', will trace")
             return True

-    log.debug(
-        f"Kernel '{kernel_name}' does not match any allowlist pattern, skipping trace"
-    )
+    log.debug(f"Kernel '{kernel_name}' does not match any allowlist pattern, skipping trace")
    return False


@@ -371,18 +362,14 @@ def extract_file_content(trace_data: Dict[str, Any], metadata_group: Dict[str, s
             # Check file size before reading to avoid memory issues
             file_size = os.path.getsize(file_path)
             if file_size > MAX_FILE_SIZE:
-                trace_data["file_content"][
-                    ir_filename
-                ] = f"<file too large: {file_size} bytes>"
+                trace_data["file_content"][ir_filename] = f"<file too large: {file_size} bytes>"
                 continue

             with open(file_path, "r") as f:
                 trace_data["file_content"][ir_filename] = f.read()
         except (UnicodeDecodeError, OSError) as e:
             # add more specific error type
-            trace_data["file_content"][
-                ir_filename
-            ] = f"<error reading file: {str(e)}>"
+            trace_data["file_content"][ir_filename] = f"<error reading file: {str(e)}>"
             log.debug(f"Error reading file {file_path}: {e}")


@@ -439,9 +426,7 @@ class TritonTraceHandler(logging.StreamHandler):
     it automatically adds rank information to filenames.
     """

-    def __init__(
-        self, root_dir: Optional[str] = None, prefix=DEFAULT_TRACE_FILE_PREFIX
-    ):
+    def __init__(self, root_dir: Optional[str] = None, prefix=DEFAULT_TRACE_FILE_PREFIX):
         logging.Handler.__init__(self)
         self.root_dir = root_dir
         self.prefix = prefix
@@ -462,13 +447,8 @@ def get_root_dir(self):
         if TORCH_INSTALLED:
             import torch.version as torch_version

-            if (
-                hasattr(torch_version, "git_version")
-                and os.getenv("MAST_HPC_JOB_NAME") is None
-            ):
-                log.info(
-                    "TritonTraceHandler: disabled because not fbcode or conda on mast"
-                )
+            if hasattr(torch_version, "git_version") and os.getenv("MAST_HPC_JOB_NAME") is None:
+                log.info("TritonTraceHandler: disabled because not fbcode or conda on mast")
                 should_set_root_dir = False
             # TODO: change to tritonparse knob
             elif not torch._utils_internal.justknobs_check("pytorch/trace:enable"):
@@ -711,9 +691,7 @@ def maybe_trace_triton(
     # Add cache_hit to metadata
     trace_data["metadata"]["cache_hit"] = cache_hit
     if not metadata:
-        metadata_path = next(
-            (Path(p) for c, p in metadata_group.items() if c.endswith(".json"))
-        )
+        metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json")))
         with open(metadata_path, "r") as f:
             metadata = json.load(f)
         trace_data["metadata"].update(metadata)
@@ -919,9 +897,7 @@ def __call__(self, metadata):
         trace_data["name"] = metadata_dict["name"]
         trace_data["function"] = metadata_dict["function"]
         trace_data["stream"] = metadata_dict["stream"]
-        launch_metadata_tritonparse = metadata_dict.get(
-            "launch_metadata_tritonparse", None
-        )
+        launch_metadata_tritonparse = metadata_dict.get("launch_metadata_tritonparse", None)
         if launch_metadata_tritonparse is not None:
             trace_data["grid"] = launch_metadata_tritonparse[0]
             trace_data["metadata"] = launch_metadata_tritonparse[1]
@@ -967,9 +943,7 @@ def init_basic(trace_folder: Optional[str] = None):
     # Parse and store kernel allowlist configuration
     _KERNEL_ALLOWLIST_PATTERNS = parse_kernel_allowlist()
     if _KERNEL_ALLOWLIST_PATTERNS:
-        log.debug(
-            f"Kernel allowlist enabled with patterns: {_KERNEL_ALLOWLIST_PATTERNS}"
-        )
+        log.debug(f"Kernel allowlist enabled with patterns: {_KERNEL_ALLOWLIST_PATTERNS}")
     else:
         log.debug("Kernel allowlist not set, tracing all kernels")
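The duplicate-handler check touched in maybe_enable_debug_logging above is a standard logging idiom: only attach a DEBUG StreamHandler if one is not already present. A minimal standalone sketch of the same idiom (illustrative only; the helper name ensure_debug_handler is hypothetical, not from this repository):

import logging

log = logging.getLogger("tritonparse_example")

def ensure_debug_handler() -> None:
    # Attach a DEBUG StreamHandler only if an equivalent one is not already installed,
    # so repeated calls do not duplicate log output.
    if not any(
        isinstance(h, logging.StreamHandler) and h.level <= logging.DEBUG
        for h in log.handlers
    ):
        handler = logging.StreamHandler()
        handler.setLevel(logging.DEBUG)
        log.addHandler(handler)
        log.setLevel(logging.DEBUG)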

tritonparse/tools/decompress_bin_ndjson.py

Lines changed: 2 additions & 6 deletions
@@ -62,9 +62,7 @@ def decompress_bin_ndjson(input_file: str, output_file: str = None) -> None:
     # Get file sizes for comparison
     input_size = input_path.stat().st_size
     output_size = output_path.stat().st_size
-    compression_ratio = (
-        (1 - input_size / output_size) * 100 if output_size > 0 else 0
-    )
+    compression_ratio = (1 - input_size / output_size) * 100 if output_size > 0 else 0

     print(f"Successfully decompressed '{input_file}' to '{output_file}'")
     print(f"  Input size: {input_size:,} bytes")
@@ -100,9 +98,7 @@ def main():
         help="Output .ndjson file path (default: replace .bin.ndjson with .ndjson)",
     )

-    parser.add_argument(
-        "-v", "--verbose", action="store_true", help="Enable verbose output"
-    )
+    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose output")

     args = parser.parse_args()


tritonparse/tools/prettify_ndjson.py

Lines changed: 6 additions & 18 deletions
@@ -98,9 +98,7 @@ def parse_line_ranges(lines_arg: str) -> set[int]:
     return line_numbers


-def load_ndjson(
-    file_path: Path, save_irs: bool = False, line_filter: set[int] = None
-) -> List[Any]:
+def load_ndjson(file_path: Path, save_irs: bool = False, line_filter: set[int] = None) -> List[Any]:
     """
     Load NDJSON file and return list of JSON objects.

@@ -157,17 +155,13 @@ def load_ndjson(
                         )  # Create a copy to avoid modifying original
                         for field in fields_to_remove:
                             del payload[field]
-                        json_obj = (
-                            json_obj.copy()
-                        )  # Create a copy of the main object
+                        json_obj = json_obj.copy()  # Create a copy of the main object
                         json_obj["payload"] = payload
                         filtered_compilation_events += 1

                 json_objects.append(json_obj)
             except json.JSONDecodeError as e:
-                print(
-                    f"Error parsing JSON on line {line_num}: {e}", file=sys.stderr
-                )
+                print(f"Error parsing JSON on line {line_num}: {e}", file=sys.stderr)
                 print(f"Problematic line: {line[:100]}...", file=sys.stderr)
                 raise

@@ -230,9 +224,7 @@ def main():
         """,
     )

-    parser.add_argument(
-        "ndjson_file", type=str, help="Path to the NDJSON file to convert"
-    )
+    parser.add_argument("ndjson_file", type=str, help="Path to the NDJSON file to convert")

     parser.add_argument(
         "--save-irs",
@@ -282,9 +274,7 @@ def main():
     if args.lines:
         try:
             line_filter = parse_line_ranges(args.lines)
-            print(
-                f"Line filtering enabled: will process {len(line_filter)} specified lines"
-            )
+            print(f"Line filtering enabled: will process {len(line_filter)} specified lines")
         except ValueError as e:
             print(f"Error parsing --lines argument: {e}", file=sys.stderr)
             sys.exit(1)
@@ -295,9 +285,7 @@ def main():
         print(
             "Filtering out file_content and python_source from compilation events to reduce size"
         )
-    json_objects = load_ndjson(
-        input_path, save_irs=args.save_irs, line_filter=line_filter
-    )
+    json_objects = load_ndjson(input_path, save_irs=args.save_irs, line_filter=line_filter)
     print(f"Loaded {len(json_objects)} JSON objects")

     # Save as prettified JSON
