Skip to content

Commit 8dc5b49

Browse files
authored
Avoid unnecessary copy in TensorSource (#8849)
1 parent 3066b03 commit 8dc5b49

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

torch_xla/csrc/runtime/tensor_source.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,12 @@ class AtenSource : public TensorSource {
5757
// TODO(ysiraichi): check, first, if tensor lives in a device that the
5858
// current PjRt client has access. If so, we don't need to go through the
5959
// CPU.
60+
// Set `copy` to false becuase torch can figure out if it needs to copy the
61+
// data or not.
6062
tensor_ = std::move(
6163
tensor.to(at::TensorOptions().device(at::kCPU).dtype(target_torch_type),
6264
/*non_blocking=*/false,
63-
/*copy=*/true, at::MemoryFormat::Contiguous));
65+
/*copy=*/false, at::MemoryFormat::Contiguous));
6466
}
6567

6668
const void* data() const override { return tensor_.const_data_ptr(); }

0 commit comments

Comments
 (0)