Skip to content

Commit e2f2852

Browse files
authored
Default explicit donation for step barriers (#8982)
1 parent 428c86d commit e2f2852

File tree

3 files changed

+61
-56
lines changed

3 files changed

+61
-56
lines changed

test/dynamo/test_dynamo_aliasing.py

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -185,41 +185,6 @@ def test_buffer_donation_on_non_data_tensor(self):
185185
self.assertNotIn('XlaSetBufferDonation', met.counter_names())
186186

187187

188-
class TestNonDynamoBufferDonationAliasing(unittest.TestCase):
189-
190-
def dummy_fn(self, input):
191-
return torch.cos(torch.sin(input))
192-
193-
# Currently let's skip the buffer donation API for the non-dynamo use case
194-
def test_buffer_donation_skip_for_non_dynamo(self):
195-
device = xm.xla_device()
196-
input = torch.randn(5, 5).to(device)
197-
xm.mark_step()
198-
met.clear_all()
199-
200-
# We should be able to set buffer donation for input tensor, but when mark_step
201-
# triggered, the buffer donation should be ignored.
202-
self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
203-
res = self.dummy_fn(input)
204-
xm.mark_step()
205-
# Make sure that input buffer is not aliased and can be used for other computations.
206-
# Also make sure that buffer_donation will not trigger recompilation in non-dynamo.
207-
self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, False))
208-
res2 = self.dummy_fn(input)
209-
xm.mark_step()
210-
torch.allclose(res.cpu(), res2.cpu())
211-
self.assertEqual(met.metric_data('CompileTime')[0], 1)
212-
213-
def test_no_op_mark_step_keep_buffer_donation(self):
214-
device = xm.xla_device()
215-
input = torch.randn(5, 5).to(device)
216-
self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
217-
xm.mark_step()
218-
self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
219-
xm.mark_step()
220-
self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
221-
222-
223188
if __name__ == '__main__':
224189
test = unittest.main()
225190
sys.exit(0 if test.result.wasSuccessful() else 1)

test/test_input_output_aliases.py

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
1+
import contextlib
2+
import copy
13
import os
24
import sys
5+
import unittest
6+
from absl.testing import parameterized
37

48
import torch
59
import torch.nn as nn
610
import torch.nn.functional as F
711
import torch_xla
812
import torch_xla.core.xla_model as xm
913
import torch_xla.debug.metrics as met
10-
import unittest
11-
import contextlib
12-
import copy
1314

1415

1516
def create_xla_config_context(set_func, get_func):
@@ -34,7 +35,7 @@ def config_context(value):
3435

3536

3637
# TODO(alanwaketan): add test for views.
37-
class InputOutputAliasesTest(unittest.TestCase):
38+
class InputOutputAliasesTest(parameterized.TestCase):
3839

3940
def test_non_view(self):
4041
xla_device = xm.xla_device()
@@ -233,34 +234,59 @@ def test_device_data_cache_no_aliasing(self):
233234
self.assertEqual(t1.item(), 43)
234235

235236
def test_user_config_donation_with_ltc_donation(self):
236-
with alias_with_buffer_donor_config_context(True):
237+
met.clear_all()
238+
xla_device = xm.xla_device()
239+
t0 = torch.randn(4, 2, 2).to(xla_device)
240+
t1 = torch.randn(4, 2, 2).to(xla_device)
241+
self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True))
242+
self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
243+
self.assertFalse(torch_xla._XLAC._get_buffer_donation(t1))
244+
t2 = t0 + t1
245+
t1 += 2
246+
xm.mark_step(wait=True)
247+
248+
# We surface the C++ runtime error by checking that the backend data is
249+
# no longer present for the IR node.
250+
self.assertTrue(torch_xla._XLAC._is_placecholder(t0))
251+
self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0)
252+
253+
@parameterized.parameters(True, False)
254+
def test_user_config_donation_with_ltc_donation_graph_sync(
255+
self, enable_buffer_donor_config):
256+
with alias_with_buffer_donor_config_context(enable_buffer_donor_config):
237257
met.clear_all()
238258
xla_device = xm.xla_device()
239259
t0 = torch.randn(4, 2, 2).to(xla_device)
240260
t1 = torch.randn(4, 2, 2).to(xla_device)
241261
self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True))
242262
self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
243263
self.assertFalse(torch_xla._XLAC._get_buffer_donation(t1))
244-
t3 = t0 + t1
264+
t2 = t0 + t1
245265
t1 += 2
246-
xm.mark_step(wait=True)
266+
# We use _xla_sync_multi to explicitly disable sync_xla_data, which will
267+
# in turn avoid using LTC aliasings. This ensures that the resulting
268+
# aliasings are due to the buffer donation.
269+
torch_xla._XLAC._xla_sync_multi([t0, t1, t2], [str(xla_device)], True,
270+
False)
247271

248272
# We surface the C++ runtime error by checking that the backend data is
249273
# no longer present for the IR node.
250-
self.assertTrue(torch_xla._XLAC._is_placecholder(t0))
251-
self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0)
274+
self.assertEqual(
275+
torch_xla._XLAC._is_placecholder(t0), enable_buffer_donor_config)
276+
self.assertEqual(
277+
met.metric_data("InputOutputAliasCount")[1],
278+
enable_buffer_donor_config)
252279

253280
def test_user_config_donation_with_ltc_donation_overlap(self):
254-
with alias_with_buffer_donor_config_context(True):
255-
met.clear_all()
256-
xla_device = xm.xla_device()
257-
t0 = torch.randn(4, 2, 2).to(xla_device)
258-
self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True))
259-
self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
260-
t0 += 2
261-
xm.mark_step()
281+
met.clear_all()
282+
xla_device = xm.xla_device()
283+
t0 = torch.randn(4, 2, 2).to(xla_device)
284+
self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True))
285+
self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
286+
t0 += 2
287+
xm.mark_step()
262288

263-
self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)
289+
self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)
264290

265291
def test_user_config_donation(self):
266292
with alias_with_buffer_donor_config_context(True):
@@ -304,6 +330,15 @@ def test_user_config_donation_no_op_mark_step(self):
304330
xm.mark_step()
305331
self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
306332

333+
def test_no_op_mark_step_keep_buffer_donation(self):
334+
xla_device = xm.xla_device()
335+
input = torch.randn(5, 5).to(xla_device)
336+
self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
337+
xm.mark_step()
338+
self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
339+
xm.mark_step()
340+
self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
341+
307342

308343
if __name__ == '__main__':
309344
test = unittest.main()

torch_xla/csrc/xla_graph_executor.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1305,9 +1305,11 @@ std::vector<size_t> XLAGraphExecutor::GetBufferDonors(
13051305
return {};
13061306
}
13071307

1308+
bool donate_ltc_data =
1309+
coll.config.sync_ltc_data && coll.config.force_ltc_data;
13081310
std::vector<size_t> ltc_buffer_donor_indices;
1309-
if (coll.config.sync_ltc_data && coll.config.force_ltc_data) {
1310-
// We can only alias at the step barrier, when force_ltc_data is true.
1311+
if (donate_ltc_data) {
1312+
// We can only alias at the step barrier, when donate_ltc_data is true.
13111313
// Consider the case:
13121314
// 1. Tensor A(DEVICE_DATA)
13131315
// 2. Tensor B = A + 0.9
@@ -1336,7 +1338,10 @@ std::vector<size_t> XLAGraphExecutor::GetBufferDonors(
13361338
}
13371339

13381340
std::vector<size_t> user_config_buffer_donor_indices;
1339-
if (GetAliasWithBufferDonorConfig()) {
1341+
if (donate_ltc_data || GetAliasWithBufferDonorConfig()) {
1342+
// In case any tensor is explicitly marked for donation, we ensure that it
1343+
// is donated during the step barrier, or if explicitly forced to donate via
1344+
// GetAliasWithBufferDonorConfig().
13401345
user_config_buffer_donor_indices =
13411346
GetBufferDonorIndexFromUserConfig(parameters_data);
13421347
}

0 commit comments

Comments
 (0)