Skip to content

Commit 335cc91

Browse files
authored
Bump tensorflow to ~=2.18.0 (#12916)
* Tensorflow proto script update * Manual stubtest changes * Use Path for arg type
1 parent e92f98c commit 335cc91

24 files changed

+747
-930
lines changed

scripts/sync_protobuf/tensorflow.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,17 +53,20 @@
5353
XLA_IMPORT_PATTERN = re.compile(r"(\[|\s)xla\.")
5454

5555

56+
def move_tree(source: Path, destination: Path) -> None:
57+
"""Move directory and merge if destination already exists.
58+
59+
Can't use shutil.move because it can't merge existing directories."""
60+
print(f"Moving '{source}' to '{destination}'")
61+
shutil.copytree(source, destination, dirs_exist_ok=True)
62+
shutil.rmtree(source)
63+
64+
5665
def post_creation() -> None:
5766
"""Move third-party and fix imports"""
58-
# Can't use shutil.move because it can't merge existing directories.
5967
print()
60-
print(f"Moving '{STUBS_FOLDER}/tsl' to '{STUBS_FOLDER}/tensorflow/tsl'")
61-
shutil.copytree(f"{STUBS_FOLDER}/tsl", f"{STUBS_FOLDER}/tensorflow/tsl", dirs_exist_ok=True)
62-
shutil.rmtree(f"{STUBS_FOLDER}/tsl")
63-
64-
print(f"Moving '{STUBS_FOLDER}/xla' to '{STUBS_FOLDER}/tensorflow/compiler/xla'")
65-
shutil.copytree(f"{STUBS_FOLDER}/xla", f"{STUBS_FOLDER}/tensorflow/compiler/xla", dirs_exist_ok=True)
66-
shutil.rmtree(f"{STUBS_FOLDER}/xla")
68+
move_tree(STUBS_FOLDER / "tsl", STUBS_FOLDER / "tensorflow" / "tsl")
69+
move_tree(STUBS_FOLDER / "xla", STUBS_FOLDER / "tensorflow" / "compiler" / "xla")
6770

6871
for path in STUBS_FOLDER.rglob("*_pb2.pyi"):
6972
print(f"Fixing imports in '{path}'")
@@ -106,6 +109,7 @@ def main() -> None:
106109
proto_globs=(
107110
f"{EXTRACTED_PACKAGE_DIR}/third_party/xla/xla/*.proto",
108111
f"{EXTRACTED_PACKAGE_DIR}/third_party/xla/xla/service/*.proto",
112+
f"{EXTRACTED_PACKAGE_DIR}/third_party/xla/xla/tsl/protobuf/*.proto",
109113
f"{EXTRACTED_PACKAGE_DIR}/tensorflow/core/example/*.proto",
110114
f"{EXTRACTED_PACKAGE_DIR}/tensorflow/core/framework/*.proto",
111115
f"{EXTRACTED_PACKAGE_DIR}/tensorflow/core/protobuf/*.proto",

stubs/protobuf/METADATA.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Using an exact number in the specifier for scripts/sync_proto/google_protobuf.py
1+
# Using an exact number in the specifier for scripts/sync_protobuf/google_protobuf.py
22
version = "~=5.28.3"
33
upstream_repository = "https://github.com/protocolbuffers/protobuf"
44
extra_description = "Partially generated using [mypy-protobuf==3.6.0](https://github.com/nipunn1313/mypy-protobuf/tree/v3.6.0) and libprotoc 27.2 on [protobuf v28.3](https://github.com/protocolbuffers/protobuf/releases/tag/v28.3) (python `protobuf==5.28.3`)."

stubs/s2clientprotocol/METADATA.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Whenever you update version here, PACKAGE_VERSION should be updated
2-
# in scripts/sync_proto/s2clientprotocol.py and vice-versa.
2+
# in scripts/sync_protobuf/s2clientprotocol.py and vice-versa.
33
version = "5.*"
44
upstream_repository = "https://github.com/Blizzard/s2client-proto"
55
requires = ["types-protobuf"]

stubs/tensorflow/METADATA.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
# Using an exact number in the specifier for scripts/sync_proto/tensorflow.py
2-
version = "~=2.17.1"
1+
# Using an exact number in the specifier for scripts/sync_protobuf/tensorflow.py
2+
version = "~=2.18.0"
33
upstream_repository = "https://github.com/tensorflow/tensorflow"
44
# requires a version of numpy with a `py.typed` file
55
# see https://github.com/python/typeshed/issues/12551
66
# on why we need the upper bound for numpy
77
requires = ["numpy>=1.20,<2.1.0", "types-protobuf", "types-requests"]
8-
extra_description = "Partially generated using [mypy-protobuf==3.6.0](https://github.com/nipunn1313/mypy-protobuf/tree/v3.6.0) and libprotoc 27.2 on `tensorflow==2.17.1`."
8+
extra_description = "Partially generated using [mypy-protobuf==3.6.0](https://github.com/nipunn1313/mypy-protobuf/tree/v3.6.0) and libprotoc 27.2 on `tensorflow==2.18.0`."
99
partial_stub = true
1010

1111
[tool.stubtest]

stubs/tensorflow/tensorflow/compiler/xla/service/hlo_pb2.pyi

Lines changed: 24 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ global___Kind = Kind
219219
@typing.final
220220
class HloInstructionProto(google.protobuf.message.Message):
221221
"""Serialization of HloInstruction.
222-
Next ID: 87
222+
Next ID: 90
223223
"""
224224

225225
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -316,6 +316,9 @@ class HloInstructionProto(google.protobuf.message.Message):
316316
LARGEST_FIELD_NUMBER: builtins.int
317317
STATISTICS_VIZ_FIELD_NUMBER: builtins.int
318318
DOT_SPARSITY_FIELD_NUMBER: builtins.int
319+
COLLECTIVE_DEVICE_LIST_FIELD_NUMBER: builtins.int
320+
ORIGINAL_VALUE_FIELD_NUMBER: builtins.int
321+
IS_COMPOSITE_FIELD_NUMBER: builtins.int
319322
name: builtins.str
320323
opcode: builtins.str
321324
parameter_number: builtins.int
@@ -433,6 +436,8 @@ class HloInstructionProto(google.protobuf.message.Message):
433436
"""Represents the K value for top-k."""
434437
largest: builtins.bool
435438
"""Represents the largest flag for top-k."""
439+
is_composite: builtins.bool
440+
"""Specifies if a call instruction is a composite."""
436441
@property
437442
def shape(self) -> tensorflow.compiler.xla.xla_data_pb2.ShapeProto: ...
438443
@property
@@ -497,7 +502,9 @@ class HloInstructionProto(google.protobuf.message.Message):
497502
def sharding(self) -> tensorflow.compiler.xla.xla_data_pb2.OpSharding: ...
498503
@property
499504
def replica_groups(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[tensorflow.compiler.xla.xla_data_pb2.ReplicaGroup]:
500-
"""Cross replica op fields."""
505+
"""Deprecated, but keeping for backward compatibility.
506+
Use collective_device_list. Cross replica op fields.
507+
"""
501508

502509
@property
503510
def scatter_dimension_numbers(self) -> tensorflow.compiler.xla.xla_data_pb2.ScatterDimensionNumbers: ...
@@ -549,6 +556,14 @@ class HloInstructionProto(google.protobuf.message.Message):
549556
def dot_sparsity(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[tensorflow.compiler.xla.xla_data_pb2.SparsityDescriptor]:
550557
"""Sparsity descriptor for dot operation."""
551558

559+
@property
560+
def collective_device_list(self) -> tensorflow.compiler.xla.xla_data_pb2.CollectiveDeviceListProto:
561+
"""Represents the list of devices that participate in a collective operation."""
562+
563+
@property
564+
def original_value(self) -> tensorflow.compiler.xla.xla_data_pb2.OriginalValueProto:
565+
"""For HLO value tracking."""
566+
552567
def __init__(
553568
self,
554569
*,
@@ -623,9 +638,12 @@ class HloInstructionProto(google.protobuf.message.Message):
623638
largest: builtins.bool | None = ...,
624639
statistics_viz: tensorflow.compiler.xla.xla_data_pb2.StatisticsViz | None = ...,
625640
dot_sparsity: collections.abc.Iterable[tensorflow.compiler.xla.xla_data_pb2.SparsityDescriptor] | None = ...,
641+
collective_device_list: tensorflow.compiler.xla.xla_data_pb2.CollectiveDeviceListProto | None = ...,
642+
original_value: tensorflow.compiler.xla.xla_data_pb2.OriginalValueProto | None = ...,
643+
is_composite: builtins.bool | None = ...,
626644
) -> None: ...
627-
def HasField(self, field_name: typing.Literal["cholesky_options", b"cholesky_options", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "frontend_attributes", b"frontend_attributes", "gather_dimension_numbers", b"gather_dimension_numbers", "literal", b"literal", "metadata", b"metadata", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "outfeed_shape", b"outfeed_shape", "padding_config", b"padding_config", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "statistics_viz", b"statistics_viz", "triangular_solve_options", b"triangular_solve_options", "window", b"window"]) -> builtins.bool: ...
628-
def ClearField(self, field_name: typing.Literal["all_reduce_id", b"all_reduce_id", "async_execution_thread", b"async_execution_thread", "backend_config", b"backend_config", "batch_group_count", b"batch_group_count", "called_computation_ids", b"called_computation_ids", "channel_id", b"channel_id", "cholesky_options", b"cholesky_options", "comparison_direction", b"comparison_direction", "comparison_type", b"comparison_type", "constrain_layout", b"constrain_layout", "control_predecessor_ids", b"control_predecessor_ids", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "custom_call_api_version", b"custom_call_api_version", "custom_call_has_side_effect", b"custom_call_has_side_effect", "custom_call_schedule", b"custom_call_schedule", "custom_call_target", b"custom_call_target", "delta", b"delta", "dimensions", b"dimensions", "distribution", b"distribution", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "dot_sparsity", b"dot_sparsity", "dynamic_slice_sizes", b"dynamic_slice_sizes", "epsilon", b"epsilon", "exponent_bits", b"exponent_bits", "feature_group_count", b"feature_group_count", "feature_index", b"feature_index", "fft_length", b"fft_length", "fft_type", b"fft_type", "frontend_attributes", b"frontend_attributes", "fusion_kind", b"fusion_kind", "gather_dimension_numbers", b"gather_dimension_numbers", "gather_slice_sizes", b"gather_slice_sizes", "id", b"id", "indices_are_sorted", b"indices_are_sorted", "infeed_config", b"infeed_config", "is_cross_program_prefetch", b"is_cross_program_prefetch", "is_host_transfer", b"is_host_transfer", "is_stable", b"is_stable", "k", b"k", "largest", b"largest", "literal", b"literal", "mantissa_bits", b"mantissa_bits", "metadata", b"metadata", "name", b"name", "opcode", b"opcode", "operand_ids", b"operand_ids", "operand_shapes_with_layout", b"operand_shapes_with_layout", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "outfeed_config", b"outfeed_config", "outfeed_shape", b"outfeed_shape", "output_operand_aliasing", b"output_operand_aliasing", "padding_config", b"padding_config", "padding_type", b"padding_type", "parameter_number", b"parameter_number", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "replica_groups", b"replica_groups", "rng_algorithm", b"rng_algorithm", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "slice_dimensions", b"slice_dimensions", "source_target_pairs", b"source_target_pairs", "statistics_viz", b"statistics_viz", "triangular_solve_options", b"triangular_solve_options", "tuple_index", b"tuple_index", "unique_indices", b"unique_indices", "use_global_device_ids", b"use_global_device_ids", "window", b"window"]) -> None: ...
645+
def HasField(self, field_name: typing.Literal["cholesky_options", b"cholesky_options", "collective_device_list", b"collective_device_list", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "frontend_attributes", b"frontend_attributes", "gather_dimension_numbers", b"gather_dimension_numbers", "literal", b"literal", "metadata", b"metadata", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "original_value", b"original_value", "outfeed_shape", b"outfeed_shape", "padding_config", b"padding_config", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "statistics_viz", b"statistics_viz", "triangular_solve_options", b"triangular_solve_options", "window", b"window"]) -> builtins.bool: ...
646+
def ClearField(self, field_name: typing.Literal["all_reduce_id", b"all_reduce_id", "async_execution_thread", b"async_execution_thread", "backend_config", b"backend_config", "batch_group_count", b"batch_group_count", "called_computation_ids", b"called_computation_ids", "channel_id", b"channel_id", "cholesky_options", b"cholesky_options", "collective_device_list", b"collective_device_list", "comparison_direction", b"comparison_direction", "comparison_type", b"comparison_type", "constrain_layout", b"constrain_layout", "control_predecessor_ids", b"control_predecessor_ids", "convolution_dimension_numbers", b"convolution_dimension_numbers", "cross_program_prefetch_index", b"cross_program_prefetch_index", "custom_call_api_version", b"custom_call_api_version", "custom_call_has_side_effect", b"custom_call_has_side_effect", "custom_call_schedule", b"custom_call_schedule", "custom_call_target", b"custom_call_target", "delta", b"delta", "dimensions", b"dimensions", "distribution", b"distribution", "domain_entry_sharding", b"domain_entry_sharding", "domain_exit_sharding", b"domain_exit_sharding", "dot_dimension_numbers", b"dot_dimension_numbers", "dot_sparsity", b"dot_sparsity", "dynamic_slice_sizes", b"dynamic_slice_sizes", "epsilon", b"epsilon", "exponent_bits", b"exponent_bits", "feature_group_count", b"feature_group_count", "feature_index", b"feature_index", "fft_length", b"fft_length", "fft_type", b"fft_type", "frontend_attributes", b"frontend_attributes", "fusion_kind", b"fusion_kind", "gather_dimension_numbers", b"gather_dimension_numbers", "gather_slice_sizes", b"gather_slice_sizes", "id", b"id", "indices_are_sorted", b"indices_are_sorted", "infeed_config", b"infeed_config", "is_composite", b"is_composite", "is_cross_program_prefetch", b"is_cross_program_prefetch", "is_host_transfer", b"is_host_transfer", "is_stable", b"is_stable", "k", b"k", "largest", b"largest", "literal", b"literal", "mantissa_bits", b"mantissa_bits", "metadata", b"metadata", "name", b"name", "opcode", b"opcode", "operand_ids", b"operand_ids", "operand_shapes_with_layout", b"operand_shapes_with_layout", "optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index", "original_value", b"original_value", "outfeed_config", b"outfeed_config", "outfeed_shape", b"outfeed_shape", "output_operand_aliasing", b"output_operand_aliasing", "padding_config", b"padding_config", "padding_type", b"padding_type", "parameter_number", b"parameter_number", "parameter_replication", b"parameter_replication", "precision_config", b"precision_config", "replica_groups", b"replica_groups", "rng_algorithm", b"rng_algorithm", "scatter_dimension_numbers", b"scatter_dimension_numbers", "shape", b"shape", "sharding", b"sharding", "slice_dimensions", b"slice_dimensions", "source_target_pairs", b"source_target_pairs", "statistics_viz", b"statistics_viz", "triangular_solve_options", b"triangular_solve_options", "tuple_index", b"tuple_index", "unique_indices", b"unique_indices", "use_global_device_ids", b"use_global_device_ids", "window", b"window"]) -> None: ...
629647
def WhichOneof(self, oneof_group: typing.Literal["optional_cross_program_prefetch_index", b"optional_cross_program_prefetch_index"]) -> typing.Literal["cross_program_prefetch_index"] | None: ...
630648

631649
global___HloInstructionProto = HloInstructionProto
@@ -980,6 +998,7 @@ class HloModuleProto(google.protobuf.message.Message):
980998
FUSION: HloModuleProto._ProfileType.ValueType # 2
981999
LAYOUT: HloModuleProto._ProfileType.ValueType # 3
9821000
DOT: HloModuleProto._ProfileType.ValueType # 4
1001+
FLAGNET: HloModuleProto._ProfileType.ValueType # 5
9831002

9841003
class ProfileType(_ProfileType, metaclass=_ProfileTypeEnumTypeWrapper):
9851004
"""The type of optimization profile in use for module-level optimizations."""
@@ -989,6 +1008,7 @@ class HloModuleProto(google.protobuf.message.Message):
9891008
FUSION: HloModuleProto.ProfileType.ValueType # 2
9901009
LAYOUT: HloModuleProto.ProfileType.ValueType # 3
9911010
DOT: HloModuleProto.ProfileType.ValueType # 4
1011+
FLAGNET: HloModuleProto.ProfileType.ValueType # 5
9921012

9931013
@typing.final
9941014
class ProfileInfo(google.protobuf.message.Message):
@@ -1604,35 +1624,3 @@ class HloPassMetadata(google.protobuf.message.Message):
16041624
def ClearField(self, field_name: typing.Literal["custom_metadata", b"custom_metadata", "dump_filenames", b"dump_filenames", "end_timestamp_usec", b"end_timestamp_usec", "module_changed", b"module_changed", "module_group_module_ids", b"module_group_module_ids", "module_id", b"module_id", "pass_id", b"pass_id", "pass_name", b"pass_name", "pipeline_name", b"pipeline_name", "start_timestamp_usec", b"start_timestamp_usec"]) -> None: ...
16051625

16061626
global___HloPassMetadata = HloPassMetadata
1607-
1608-
@typing.final
1609-
class XlaRuntimeExecutableProto(google.protobuf.message.Message):
1610-
"""Encodes the underlying Xla runtime executable compiled from the XLA module."""
1611-
1612-
DESCRIPTOR: google.protobuf.descriptor.Descriptor
1613-
1614-
HLO_MODULE_PROTO_FIELD_NUMBER: builtins.int
1615-
OBJ_FILE_FIELD_NUMBER: builtins.int
1616-
MLIR_MODULE_FIELD_NUMBER: builtins.int
1617-
obj_file: builtins.bytes
1618-
"""TODO(b/232263665)): Serialized executable has to know what APIs it has to
1619-
be linked with, including the version. For example Gpu executable must be
1620-
linked with a runtime layer that abstracts over CUDA.
1621-
1622-
Serialized object file compiled from the XLA module.
1623-
"""
1624-
mlir_module: builtins.str
1625-
"""Serialized MLIR module corresponding to compiled object file."""
1626-
@property
1627-
def hlo_module_proto(self) -> global___HloModuleProto: ...
1628-
def __init__(
1629-
self,
1630-
*,
1631-
hlo_module_proto: global___HloModuleProto | None = ...,
1632-
obj_file: builtins.bytes | None = ...,
1633-
mlir_module: builtins.str | None = ...,
1634-
) -> None: ...
1635-
def HasField(self, field_name: typing.Literal["hlo_module_proto", b"hlo_module_proto"]) -> builtins.bool: ...
1636-
def ClearField(self, field_name: typing.Literal["hlo_module_proto", b"hlo_module_proto", "mlir_module", b"mlir_module", "obj_file", b"obj_file"]) -> None: ...
1637-
1638-
global___XlaRuntimeExecutableProto = XlaRuntimeExecutableProto

0 commit comments

Comments
 (0)