
PyTorch/XLA 2.7 release

Released by @zpcore on 24 Apr (commit d5871fc)

Highlights

  • Easier training on Cloud TPUs with TorchPrime
  • A new Pallas-based kernel for ragged paged attention, enabling further optimizations on vLLM TPU (#8791)
  • Usability improvements
  • Experimental interoperability with JAX operations (#8781, #8789, #8830, #8878); see the sketch after this list
  • Re-enabled the GPU CI build (#8593)
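
A minimal sketch of the experimental JAX interoperability, assuming torch_xla.core.xla_builder.call_jax is the interop entry point; the exact module path and signature may differ from what ships in 2.7.

```python
# Sketch only: the call_jax entry point and its signature are assumptions
# based on the interop work referenced above (#8781, #8789, #8830, #8878).
import jax.numpy as jnp
import torch
import torch_xla
import torch_xla.core.xla_builder as xb


def jax_add_and_scale(a, b):
    # A plain JAX function; it is traced and lowered to XLA alongside
    # the surrounding PyTorch/XLA operations.
    return jnp.sin(a) + 2.0 * b


device = torch_xla.device()
x = torch.randn(4, 4, device=device)
y = torch.randn(4, 4, device=device)

z = xb.call_jax(jax_add_and_scale, (x, y))  # returns an XLA tensor
torch_xla.sync()
```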

Stable Features

  • Operator Lowering
    • Lower as_strided_copy to use fast path with slice (#8374)
    • Lower _conj_copy (#8686)
  • Support splitting a physical axis in the SPMD mesh (#8698); see the mesh sketch after this list
  • Support for placeholder tensors (#8785)
  • Dynamo/AOTAutograd-traceable flash attention (#8654)
  • C++11 ABI builds are now the default
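
A minimal sketch of building a multi-axis logical SPMD mesh over the flat device list and sharding a tensor across it; it uses the existing Mesh/mark_sharding API and does not show the new physical-axis splitting knob from #8698 itself.

```python
import numpy as np
import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable SPMD execution mode

num_devices = xr.global_runtime_device_count()
# Factor the flat device list into a (data, model) logical mesh,
# e.g. 4 x 2 on an 8-chip host.
mesh = xs.Mesh(np.arange(num_devices), (num_devices // 2, 2), ("data", "model"))

t = torch.randn(128, 1024, device=torch_xla.device())
# Shard dim 0 over the 'data' axis and dim 1 over the 'model' axis.
xs.mark_sharding(t, mesh, ("data", "model"))
```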

Experimental Features

  • Gated Recurrent Unit (GRU) implemented with scan (#8777)
  • Introduce apply_xla_patch_to_nn_linear to improve einsum performance (#8793); see the sketch after this list
  • Enable default buffer donation for step barriers (#8721, #8982)
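
A minimal sketch of the einsum patch for nn.Linear; the import path and exact signature of apply_xla_patch_to_nn_linear below are assumptions, so check the module that ships in your install.

```python
import torch
import torch.nn as nn
import torch_xla
import torch_xla.distributed.spmd as xs  # assumed location of the patch helper


class TinyMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1024, 4096)
        self.fc2 = nn.Linear(4096, 1024)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


model = TinyMLP().to(torch_xla.device())
# Rewrites the Linear forwards to einsum-based implementations so that
# einsum performance improves (#8793).
model = xs.apply_xla_patch_to_nn_linear(model)
```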

Usability

  • Better profiling control: the start and end of a profiling session can now be controlled via the new profiler API (#8743); see the sketch after this list
  • API to query number of cached compilation graphs (#8822)
  • Enhancements to host-to-device transfers (#8849)
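
A minimal sketch of bracketing a profiling session programmatically, assuming the new API from #8743 exposes start_trace/stop_trace in torch_xla.debug.profiler; the exact names may differ.

```python
import torch
import torch_xla
import torch_xla.debug.profiler as xp

device = torch_xla.device()
model = torch.nn.Linear(1024, 1024).to(device)
x = torch.randn(64, 1024, device=device)

xp.start_trace("/tmp/xla_profile")  # begin capturing a trace into this directory
for _ in range(10):
    y = model(x)
    torch_xla.sync()
xp.stop_trace()  # end the session; inspect the trace in TensorBoard
```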

Bug fixes

  • Fix a bug in tensor.flatten (#8680)
  • Fix cummax reduction over 0-sized dimensions (#8653)
  • Fix dk/dv autograd error on TPU flash attention (#8685)
  • Fix a bug in flash attention where kv_seq_len should divide block_k_major (#8671)
  • [scan] Make sure inputs into fn are not device_data IR nodes (#8769)

Libtpu stable version

  • Pin 2.7 release to stable libtpu version '0.0.11.1'

Deprecations

  • Deprecate the torch.export integration; use torchax instead to export graphs to StableHLO with full dynamism support
  • Remove torch_xla.core.xla_model.xrt_world_size; use torch_xla.runtime.world_size instead
  • Remove torch_xla.core.xla_model.get_ordinal; use torch_xla.runtime.global_ordinal instead
  • Remove torch_xla.core.xla_model.parse_xla_device; use _utils.parse_xla_device instead
  • Remove torch_xla.experimental.compile; use torch_xla.compile instead (see the migration sketch below)
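
A minimal migration sketch for the removed helpers, using the replacement calls named above.

```python
import torch
import torch_xla
import torch_xla.runtime as xr

# xm.xrt_world_size()  ->  xr.world_size()
# xm.get_ordinal()     ->  xr.global_ordinal()
world = xr.world_size()
rank = xr.global_ordinal()
print(f"process {rank} of {world}")

# torch_xla.experimental.compile  ->  torch_xla.compile
def train_step(model, x):
    return model(x).sum()

compiled_step = torch_xla.compile(train_step)
```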