@@ -219,7 +219,7 @@ global___Kind = Kind
219
219
@typing .final
220
220
class HloInstructionProto (google .protobuf .message .Message ):
221
221
"""Serialization of HloInstruction.
222
- Next ID: 87
222
+ Next ID: 90
223
223
"""
224
224
225
225
DESCRIPTOR : google .protobuf .descriptor .Descriptor
@@ -316,6 +316,9 @@ class HloInstructionProto(google.protobuf.message.Message):
316
316
LARGEST_FIELD_NUMBER : builtins .int
317
317
STATISTICS_VIZ_FIELD_NUMBER : builtins .int
318
318
DOT_SPARSITY_FIELD_NUMBER : builtins .int
319
+ COLLECTIVE_DEVICE_LIST_FIELD_NUMBER : builtins .int
320
+ ORIGINAL_VALUE_FIELD_NUMBER : builtins .int
321
+ IS_COMPOSITE_FIELD_NUMBER : builtins .int
319
322
name : builtins .str
320
323
opcode : builtins .str
321
324
parameter_number : builtins .int
@@ -433,6 +436,8 @@ class HloInstructionProto(google.protobuf.message.Message):
433
436
"""Represents the K value for top-k."""
434
437
largest : builtins .bool
435
438
"""Represents the largest flag for top-k."""
439
+ is_composite : builtins .bool
440
+ """Specifies if a call instruction is a composite."""
436
441
@property
437
442
def shape (self ) -> tensorflow .compiler .xla .xla_data_pb2 .ShapeProto : ...
438
443
@property
@@ -497,7 +502,9 @@ class HloInstructionProto(google.protobuf.message.Message):
497
502
def sharding (self ) -> tensorflow .compiler .xla .xla_data_pb2 .OpSharding : ...
498
503
@property
499
504
def replica_groups (self ) -> google .protobuf .internal .containers .RepeatedCompositeFieldContainer [tensorflow .compiler .xla .xla_data_pb2 .ReplicaGroup ]:
500
- """Cross replica op fields."""
505
+ """Deprecated, but keeping for backward compatibility.
506
+ Use collective_device_list. Cross replica op fields.
507
+ """
501
508
502
509
@property
503
510
def scatter_dimension_numbers (self ) -> tensorflow .compiler .xla .xla_data_pb2 .ScatterDimensionNumbers : ...
@@ -549,6 +556,14 @@ class HloInstructionProto(google.protobuf.message.Message):
549
556
def dot_sparsity (self ) -> google .protobuf .internal .containers .RepeatedCompositeFieldContainer [tensorflow .compiler .xla .xla_data_pb2 .SparsityDescriptor ]:
550
557
"""Sparsity descriptor for dot operation."""
551
558
559
+ @property
560
+ def collective_device_list (self ) -> tensorflow .compiler .xla .xla_data_pb2 .CollectiveDeviceListProto :
561
+ """Represents the list of devices that participate in a collective operation."""
562
+
563
+ @property
564
+ def original_value (self ) -> tensorflow .compiler .xla .xla_data_pb2 .OriginalValueProto :
565
+ """For HLO value tracking."""
566
+
552
567
def __init__ (
553
568
self ,
554
569
* ,
@@ -623,9 +638,12 @@ class HloInstructionProto(google.protobuf.message.Message):
623
638
largest : builtins .bool | None = ...,
624
639
statistics_viz : tensorflow .compiler .xla .xla_data_pb2 .StatisticsViz | None = ...,
625
640
dot_sparsity : collections .abc .Iterable [tensorflow .compiler .xla .xla_data_pb2 .SparsityDescriptor ] | None = ...,
641
+ collective_device_list : tensorflow .compiler .xla .xla_data_pb2 .CollectiveDeviceListProto | None = ...,
642
+ original_value : tensorflow .compiler .xla .xla_data_pb2 .OriginalValueProto | None = ...,
643
+ is_composite : builtins .bool | None = ...,
626
644
) -> None : ...
627
- def HasField (self , field_name : typing .Literal ["cholesky_options" , b"cholesky_options" , "convolution_dimension_numbers" , b"convolution_dimension_numbers" , "cross_program_prefetch_index" , b"cross_program_prefetch_index" , "domain_entry_sharding" , b"domain_entry_sharding" , "domain_exit_sharding" , b"domain_exit_sharding" , "dot_dimension_numbers" , b"dot_dimension_numbers" , "frontend_attributes" , b"frontend_attributes" , "gather_dimension_numbers" , b"gather_dimension_numbers" , "literal" , b"literal" , "metadata" , b"metadata" , "optional_cross_program_prefetch_index" , b"optional_cross_program_prefetch_index" , "outfeed_shape" , b"outfeed_shape" , "padding_config" , b"padding_config" , "parameter_replication" , b"parameter_replication" , "precision_config" , b"precision_config" , "scatter_dimension_numbers" , b"scatter_dimension_numbers" , "shape" , b"shape" , "sharding" , b"sharding" , "statistics_viz" , b"statistics_viz" , "triangular_solve_options" , b"triangular_solve_options" , "window" , b"window" ]) -> builtins .bool : ...
628
- def ClearField (self , field_name : typing .Literal ["all_reduce_id" , b"all_reduce_id" , "async_execution_thread" , b"async_execution_thread" , "backend_config" , b"backend_config" , "batch_group_count" , b"batch_group_count" , "called_computation_ids" , b"called_computation_ids" , "channel_id" , b"channel_id" , "cholesky_options" , b"cholesky_options" , "comparison_direction" , b"comparison_direction" , "comparison_type" , b"comparison_type" , "constrain_layout" , b"constrain_layout" , "control_predecessor_ids" , b"control_predecessor_ids" , "convolution_dimension_numbers" , b"convolution_dimension_numbers" , "cross_program_prefetch_index" , b"cross_program_prefetch_index" , "custom_call_api_version" , b"custom_call_api_version" , "custom_call_has_side_effect" , b"custom_call_has_side_effect" , "custom_call_schedule" , b"custom_call_schedule" , "custom_call_target" , b"custom_call_target" , "delta" , b"delta" , "dimensions" , b"dimensions" , "distribution" , b"distribution" , "domain_entry_sharding" , b"domain_entry_sharding" , "domain_exit_sharding" , b"domain_exit_sharding" , "dot_dimension_numbers" , b"dot_dimension_numbers" , "dot_sparsity" , b"dot_sparsity" , "dynamic_slice_sizes" , b"dynamic_slice_sizes" , "epsilon" , b"epsilon" , "exponent_bits" , b"exponent_bits" , "feature_group_count" , b"feature_group_count" , "feature_index" , b"feature_index" , "fft_length" , b"fft_length" , "fft_type" , b"fft_type" , "frontend_attributes" , b"frontend_attributes" , "fusion_kind" , b"fusion_kind" , "gather_dimension_numbers" , b"gather_dimension_numbers" , "gather_slice_sizes" , b"gather_slice_sizes" , "id" , b"id" , "indices_are_sorted" , b"indices_are_sorted" , "infeed_config" , b"infeed_config" , "is_cross_program_prefetch" , b"is_cross_program_prefetch" , "is_host_transfer" , b"is_host_transfer" , "is_stable" , b"is_stable" , "k" , b"k" , "largest" , b"largest" , "literal" , b"literal" , "mantissa_bits" , b"mantissa_bits" , "metadata" , b"metadata" , "name" , b"name" , "opcode" , b"opcode" , "operand_ids" , b"operand_ids" , "operand_shapes_with_layout" , b"operand_shapes_with_layout" , "optional_cross_program_prefetch_index" , b"optional_cross_program_prefetch_index" , "outfeed_config" , b"outfeed_config" , "outfeed_shape" , b"outfeed_shape" , "output_operand_aliasing" , b"output_operand_aliasing" , "padding_config" , b"padding_config" , "padding_type" , b"padding_type" , "parameter_number" , b"parameter_number" , "parameter_replication" , b"parameter_replication" , "precision_config" , b"precision_config" , "replica_groups" , b"replica_groups" , "rng_algorithm" , b"rng_algorithm" , "scatter_dimension_numbers" , b"scatter_dimension_numbers" , "shape" , b"shape" , "sharding" , b"sharding" , "slice_dimensions" , b"slice_dimensions" , "source_target_pairs" , b"source_target_pairs" , "statistics_viz" , b"statistics_viz" , "triangular_solve_options" , b"triangular_solve_options" , "tuple_index" , b"tuple_index" , "unique_indices" , b"unique_indices" , "use_global_device_ids" , b"use_global_device_ids" , "window" , b"window" ]) -> None : ...
645
+ def HasField (self , field_name : typing .Literal ["cholesky_options" , b"cholesky_options" , "collective_device_list" , b"collective_device_list" , "convolution_dimension_numbers" , b"convolution_dimension_numbers" , "cross_program_prefetch_index" , b"cross_program_prefetch_index" , "domain_entry_sharding" , b"domain_entry_sharding" , "domain_exit_sharding" , b"domain_exit_sharding" , "dot_dimension_numbers" , b"dot_dimension_numbers" , "frontend_attributes" , b"frontend_attributes" , "gather_dimension_numbers" , b"gather_dimension_numbers" , "literal" , b"literal" , "metadata" , b"metadata" , "optional_cross_program_prefetch_index" , b"optional_cross_program_prefetch_index" , "original_value" , b"original_value" , "outfeed_shape" , b"outfeed_shape" , "padding_config" , b"padding_config" , "parameter_replication" , b"parameter_replication" , "precision_config" , b"precision_config" , "scatter_dimension_numbers" , b"scatter_dimension_numbers" , "shape" , b"shape" , "sharding" , b"sharding" , "statistics_viz" , b"statistics_viz" , "triangular_solve_options" , b"triangular_solve_options" , "window" , b"window" ]) -> builtins .bool : ...
646
+ def ClearField (self , field_name : typing .Literal ["all_reduce_id" , b"all_reduce_id" , "async_execution_thread" , b"async_execution_thread" , "backend_config" , b"backend_config" , "batch_group_count" , b"batch_group_count" , "called_computation_ids" , b"called_computation_ids" , "channel_id" , b"channel_id" , "cholesky_options" , b"cholesky_options" , "collective_device_list" , b"collective_device_list" , "comparison_direction" , b"comparison_direction" , "comparison_type" , b"comparison_type" , "constrain_layout" , b"constrain_layout" , "control_predecessor_ids" , b"control_predecessor_ids" , "convolution_dimension_numbers" , b"convolution_dimension_numbers" , "cross_program_prefetch_index" , b"cross_program_prefetch_index" , "custom_call_api_version" , b"custom_call_api_version" , "custom_call_has_side_effect" , b"custom_call_has_side_effect" , "custom_call_schedule" , b"custom_call_schedule" , "custom_call_target" , b"custom_call_target" , "delta" , b"delta" , "dimensions" , b"dimensions" , "distribution" , b"distribution" , "domain_entry_sharding" , b"domain_entry_sharding" , "domain_exit_sharding" , b"domain_exit_sharding" , "dot_dimension_numbers" , b"dot_dimension_numbers" , "dot_sparsity" , b"dot_sparsity" , "dynamic_slice_sizes" , b"dynamic_slice_sizes" , "epsilon" , b"epsilon" , "exponent_bits" , b"exponent_bits" , "feature_group_count" , b"feature_group_count" , "feature_index" , b"feature_index" , "fft_length" , b"fft_length" , "fft_type" , b"fft_type" , "frontend_attributes" , b"frontend_attributes" , "fusion_kind" , b"fusion_kind" , "gather_dimension_numbers" , b"gather_dimension_numbers" , "gather_slice_sizes" , b"gather_slice_sizes" , "id" , b"id" , "indices_are_sorted" , b"indices_are_sorted" , "infeed_config" , b"infeed_config" , "is_composite" , b"is_composite" , "is_cross_program_prefetch" , b"is_cross_program_prefetch" , "is_host_transfer" , b"is_host_transfer" , "is_stable" , b"is_stable" , "k" , b"k" , "largest" , b"largest" , "literal" , b"literal" , "mantissa_bits" , b"mantissa_bits" , "metadata" , b"metadata" , "name" , b"name" , "opcode" , b"opcode" , "operand_ids" , b"operand_ids" , "operand_shapes_with_layout" , b"operand_shapes_with_layout" , "optional_cross_program_prefetch_index" , b"optional_cross_program_prefetch_index" , "original_value" , b"original_value" , "outfeed_config" , b"outfeed_config" , "outfeed_shape" , b"outfeed_shape" , "output_operand_aliasing" , b"output_operand_aliasing" , "padding_config" , b"padding_config" , "padding_type" , b"padding_type" , "parameter_number" , b"parameter_number" , "parameter_replication" , b"parameter_replication" , "precision_config" , b"precision_config" , "replica_groups" , b"replica_groups" , "rng_algorithm" , b"rng_algorithm" , "scatter_dimension_numbers" , b"scatter_dimension_numbers" , "shape" , b"shape" , "sharding" , b"sharding" , "slice_dimensions" , b"slice_dimensions" , "source_target_pairs" , b"source_target_pairs" , "statistics_viz" , b"statistics_viz" , "triangular_solve_options" , b"triangular_solve_options" , "tuple_index" , b"tuple_index" , "unique_indices" , b"unique_indices" , "use_global_device_ids" , b"use_global_device_ids" , "window" , b"window" ]) -> None : ...
629
647
def WhichOneof (self , oneof_group : typing .Literal ["optional_cross_program_prefetch_index" , b"optional_cross_program_prefetch_index" ]) -> typing .Literal ["cross_program_prefetch_index" ] | None : ...
630
648
631
649
global___HloInstructionProto = HloInstructionProto
@@ -980,6 +998,7 @@ class HloModuleProto(google.protobuf.message.Message):
980
998
FUSION : HloModuleProto ._ProfileType .ValueType # 2
981
999
LAYOUT : HloModuleProto ._ProfileType .ValueType # 3
982
1000
DOT : HloModuleProto ._ProfileType .ValueType # 4
1001
+ FLAGNET : HloModuleProto ._ProfileType .ValueType # 5
983
1002
984
1003
class ProfileType (_ProfileType , metaclass = _ProfileTypeEnumTypeWrapper ):
985
1004
"""The type of optimization profile in use for module-level optimizations."""
@@ -989,6 +1008,7 @@ class HloModuleProto(google.protobuf.message.Message):
989
1008
FUSION : HloModuleProto .ProfileType .ValueType # 2
990
1009
LAYOUT : HloModuleProto .ProfileType .ValueType # 3
991
1010
DOT : HloModuleProto .ProfileType .ValueType # 4
1011
+ FLAGNET : HloModuleProto .ProfileType .ValueType # 5
992
1012
993
1013
@typing .final
994
1014
class ProfileInfo (google .protobuf .message .Message ):
@@ -1604,35 +1624,3 @@ class HloPassMetadata(google.protobuf.message.Message):
1604
1624
def ClearField (self , field_name : typing .Literal ["custom_metadata" , b"custom_metadata" , "dump_filenames" , b"dump_filenames" , "end_timestamp_usec" , b"end_timestamp_usec" , "module_changed" , b"module_changed" , "module_group_module_ids" , b"module_group_module_ids" , "module_id" , b"module_id" , "pass_id" , b"pass_id" , "pass_name" , b"pass_name" , "pipeline_name" , b"pipeline_name" , "start_timestamp_usec" , b"start_timestamp_usec" ]) -> None : ...
1605
1625
1606
1626
global___HloPassMetadata = HloPassMetadata
1607
-
1608
- @typing .final
1609
- class XlaRuntimeExecutableProto (google .protobuf .message .Message ):
1610
- """Encodes the underlying Xla runtime executable compiled from the XLA module."""
1611
-
1612
- DESCRIPTOR : google .protobuf .descriptor .Descriptor
1613
-
1614
- HLO_MODULE_PROTO_FIELD_NUMBER : builtins .int
1615
- OBJ_FILE_FIELD_NUMBER : builtins .int
1616
- MLIR_MODULE_FIELD_NUMBER : builtins .int
1617
- obj_file : builtins .bytes
1618
- """TODO(b/232263665)): Serialized executable has to know what APIs it has to
1619
- be linked with, including the version. For example Gpu executable must be
1620
- linked with a runtime layer that abstracts over CUDA.
1621
-
1622
- Serialized object file compiled from the XLA module.
1623
- """
1624
- mlir_module : builtins .str
1625
- """Serialized MLIR module corresponding to compiled object file."""
1626
- @property
1627
- def hlo_module_proto (self ) -> global___HloModuleProto : ...
1628
- def __init__ (
1629
- self ,
1630
- * ,
1631
- hlo_module_proto : global___HloModuleProto | None = ...,
1632
- obj_file : builtins .bytes | None = ...,
1633
- mlir_module : builtins .str | None = ...,
1634
- ) -> None : ...
1635
- def HasField (self , field_name : typing .Literal ["hlo_module_proto" , b"hlo_module_proto" ]) -> builtins .bool : ...
1636
- def ClearField (self , field_name : typing .Literal ["hlo_module_proto" , b"hlo_module_proto" , "mlir_module" , b"mlir_module" , "obj_file" , b"obj_file" ]) -> None : ...
1637
-
1638
- global___XlaRuntimeExecutableProto = XlaRuntimeExecutableProto
0 commit comments