diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 37adaef7979..51781ee91b7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -49,3 +49,17 @@ repos: entry: autoflake --in-place --remove-unused-variables --remove-all-unused-imports language: system types: [python] + + - id: check-future-annotations + name: Check for `from __future__ import annotations` + entry: python packaging/check_future_annotations.py + language: python + files: \.py$ + exclude: run-clang-format\.py + + - id: check-header + name: Check for Meta copyright header + entry: python packaging/check_headers.py + language: python + files: \.py$ + exclude: run-clang-format\.py diff --git a/packaging/check_future_annotations.py b/packaging/check_future_annotations.py new file mode 100644 index 00000000000..4776bf20bc2 --- /dev/null +++ b/packaging/check_future_annotations.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import sys + + +def add_future_import(filename): + with open(filename, encoding="utf-8") as f: + lines = f.readlines() + + # Check if the import is already present + for line in lines: + if line.strip() == "from __future__ import annotations": + return # Import already present, no need to modify + + # Find the position to insert the import + insert_pos = 0 + for i, line in enumerate(lines): + stripped_line = line.strip() + if stripped_line and not stripped_line.startswith("#"): + insert_pos = i + break + + # Insert the import statement after the first comment block + lines.insert(insert_pos, "from __future__ import annotations\n\n") + + # Write the modified lines back to the file + with open(filename, "w", encoding="utf-8") as f: + f.writelines(lines) + + +def main(): + files = sys.argv[1:] + for f in files: + add_future_import(f) + print("Processed files to ensure `from __future__ import annotations` is present.") + + +if __name__ == "__main__": + main() diff --git a/packaging/check_headers.py b/packaging/check_headers.py new file mode 100644 index 00000000000..5620122da0c --- /dev/null +++ b/packaging/check_headers.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import sys + +HEADER = """# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" + + +def check_header(filename): + with open(filename, encoding="utf-8") as f: + file_content = f.read() + + if not file_content.startswith(HEADER): + print(f"Missing or incorrect header in {filename}") + return False + return True + + +def main(): + files = sys.argv[1:] + all_passed = True + for f in files: + if not check_header(f): + all_passed = False + if not all_passed: + sys.exit(1) + sys.exit(0) + + +if __name__ == "__main__": + main()