@@ -1,15 +1,16 @@
+import contextlib
+import copy
 import os
 import sys
+import unittest
+from absl.testing import parameterized
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch_xla
 import torch_xla.core.xla_model as xm
 import torch_xla.debug.metrics as met
-import unittest
-import contextlib
-import copy
 
 
 def create_xla_config_context(set_func, get_func):
@@ -34,7 +35,7 @@ def config_context(value):
 
 
 # TODO(alanwaketan): add test for views.
-class InputOutputAliasesTest(unittest.TestCase):
+class InputOutputAliasesTest(parameterized.TestCase):
 
   def test_non_view(self):
     xla_device = xm.xla_device()
@@ -233,34 +234,59 @@ def test_device_data_cache_no_aliasing(self):
     self.assertEqual(t1.item(), 43)
 
   def test_user_config_donation_with_ltc_donation(self):
-    with alias_with_buffer_donor_config_context(True):
+    met.clear_all()
+    xla_device = xm.xla_device()
+    t0 = torch.randn(4, 2, 2).to(xla_device)
+    t1 = torch.randn(4, 2, 2).to(xla_device)
+    self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True))
+    self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
+    self.assertFalse(torch_xla._XLAC._get_buffer_donation(t1))
+    t2 = t0 + t1
+    t1 += 2
+    xm.mark_step(wait=True)
+
+    # We surface the C++ runtime error by checking that the backend data is
+    # no longer present for the IR node.
+    self.assertTrue(torch_xla._XLAC._is_placecholder(t0))
+    self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0)
+
+  @parameterized.parameters(True, False)
+  def test_user_config_donation_with_ltc_donation_graph_sync(
+      self, enable_buffer_donor_config):
+    with alias_with_buffer_donor_config_context(enable_buffer_donor_config):
       met.clear_all()
       xla_device = xm.xla_device()
       t0 = torch.randn(4, 2, 2).to(xla_device)
       t1 = torch.randn(4, 2, 2).to(xla_device)
       self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True))
       self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
       self.assertFalse(torch_xla._XLAC._get_buffer_donation(t1))
-      t3 = t0 + t1
+      t2 = t0 + t1
       t1 += 2
-      xm.mark_step(wait=True)
+      # We use _xla_sync_multi to explicitly disable sync_xla_data, which will
+      # in turn avoid using LTC aliasings. This ensures that the resulting
+      # aliasings are due to the buffer donation.
+      torch_xla._XLAC._xla_sync_multi([t0, t1, t2], [str(xla_device)], True,
                                      False)
 
       # We surface the C++ runtime error by checking that the backend data is
       # no longer present for the IR node.
-      self.assertTrue(torch_xla._XLAC._is_placecholder(t0))
-      self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 2.0)
+      self.assertEqual(
+          torch_xla._XLAC._is_placecholder(t0), enable_buffer_donor_config)
+      self.assertEqual(
+          met.metric_data("InputOutputAliasCount")[1],
+          enable_buffer_donor_config)
 
   def test_user_config_donation_with_ltc_donation_overlap(self):
-    with alias_with_buffer_donor_config_context(True):
-      met.clear_all()
-      xla_device = xm.xla_device()
-      t0 = torch.randn(4, 2, 2).to(xla_device)
-      self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True))
-      self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
-      t0 += 2
-      xm.mark_step()
+    met.clear_all()
+    xla_device = xm.xla_device()
+    t0 = torch.randn(4, 2, 2).to(xla_device)
+    self.assertTrue(torch_xla._XLAC._set_buffer_donation(t0, True))
+    self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
+    t0 += 2
+    xm.mark_step()
 
-      self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)
+    self.assertEqual(met.metric_data("InputOutputAliasCount")[1], 1.0)
 
   def test_user_config_donation(self):
     with alias_with_buffer_donor_config_context(True):
@@ -304,6 +330,15 @@ def test_user_config_donation_no_op_mark_step(self):
       xm.mark_step()
       self.assertTrue(torch_xla._XLAC._get_buffer_donation(t0))
 
+  def test_no_op_mark_step_keep_buffer_donation(self):
+    xla_device = xm.xla_device()
+    input = torch.randn(5, 5).to(xla_device)
+    self.assertTrue(torch_xla._XLAC._set_buffer_donation(input, True))
+    xm.mark_step()
+    self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
+    xm.mark_step()
+    self.assertTrue(torch_xla._XLAC._get_buffer_donation(input))
+
 
 if __name__ == '__main__':
   test = unittest.main()
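
The buffer-donation flow exercised by these tests can also be reproduced on its own. The snippet below is a minimal, illustrative sketch (not part of this change) that mirrors test_user_config_donation_with_ltc_donation; it assumes a torch_xla build exposing the _XLAC bindings referenced in the diff (_set_buffer_donation, _get_buffer_donation, _is_placecholder).

# Minimal sketch: donate an input buffer and observe the resulting aliasing,
# mirroring test_user_config_donation_with_ltc_donation above.
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

met.clear_all()
device = xm.xla_device()
t0 = torch.randn(4, 2, 2).to(device)
t1 = torch.randn(4, 2, 2).to(device)

# Ask the runtime to donate t0's device buffer to the next computation.
assert torch_xla._XLAC._set_buffer_donation(t0, True)
assert torch_xla._XLAC._get_buffer_donation(t0)

t2 = t0 + t1  # t0 is used as an input only
t1 += 2       # in-place update lets LTC alias t1's buffer as well
xm.mark_step(wait=True)

# t0's buffer was donated, so its backend data is gone (a placeholder), and
# both the donation and the LTC alias show up in the aliasing metric.
print(torch_xla._XLAC._is_placecholder(t0))         # expected: True
print(met.metric_data("InputOutputAliasCount")[1])  # expected: 2.0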