@@ -446,7 +446,7 @@ def f(x, y):
446
446
f (x1 , x2 )
447
447
448
448
@jtu .with_mesh ([('x' , 2 ), ('y' , 1 )])
449
- def testShardingConstraintStablehlo (self ):
449
+ def testShardingConstraintMeshContext (self ):
450
450
@partial (pjit , in_shardings = None , out_shardings = None )
451
451
def f (x ):
452
452
y = x + 1
@@ -463,17 +463,17 @@ def f(x):
463
463
self .assertAllClose (np .asarray (actual .addressable_shards [0 ].data ), expected ,
464
464
check_dtypes = False )
465
465
466
- hlo = f .lower (np .ones (shape )).compiler_ir ()
466
+ lowered_text = f .lower (np .ones (shape )).as_text ()
467
467
if config .use_shardy_partitioner .value :
468
468
# Annotation from with_sharding_constraint
469
- self .assertIn ('<@mesh, [{"x"}, {"y"}]>' , str ( hlo ) )
469
+ self .assertIn ('<@mesh, [{"x"}, {"y"}]>' , lowered_text )
470
470
# Annotation from pjit
471
- self .assertIn ('sharding = #sdy.sharding<@mesh, [{}, {}]>}' , str ( hlo ) )
471
+ self .assertIn ('sdy. sharding = #sdy.sharding<@mesh, [{}, {}]>}' , lowered_text )
472
472
else :
473
473
# Annotation from with_sharding_constraint
474
- self .assertIn ('sharding = "{devices=[2,1]<=[2]}"' , str ( hlo ) )
474
+ self .assertIn ('mhlo. sharding = "{devices=[2,1]<=[2]}"' , lowered_text )
475
475
# Annotation from pjit
476
- self .assertIn ('sharding = "{replicated}"' , str ( hlo ) )
476
+ self .assertIn ('mhlo. sharding = "{replicated}"' , lowered_text )
477
477
478
478
def testShardingConstraintWithArray (self ):
479
479
mesh = jtu .create_mesh ((2 , 1 ), ('x' , 'y' ))
@@ -494,11 +494,17 @@ def f(x):
494
494
self .assertLen (actual .addressable_shards , 2 )
495
495
self .assertAllClose (actual , expected , check_dtypes = False )
496
496
497
- hlo = f .lower (np .ones (shape )).compiler_ir (dialect = "hlo" )
498
- # Annotation from with_sharding_constraint
499
- self .assertIn ('sharding={devices=[2,1]<=[2]}' , hlo .as_hlo_text ())
500
- # Annotation from pjit
501
- self .assertIn ("sharding={replicated}" , hlo .as_hlo_text ())
497
+ lowered_text = f .lower (np .ones (shape )).as_text ()
498
+ if config .use_shardy_partitioner .value :
499
+ # Annotation from with_sharding_constraint
500
+ self .assertIn ('<@mesh, [{"x"}, {"y"}]>' , lowered_text )
501
+ # Annotation from pjit
502
+ self .assertIn ('sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>}' , lowered_text )
503
+ else :
504
+ # Annotation from with_sharding_constraint
505
+ self .assertIn ('mhlo.sharding = "{devices=[2,1]<=[2]}"' , lowered_text )
506
+ # Annotation from pjit
507
+ self .assertIn ('mhlo.sharding = "{replicated}"' , lowered_text )
502
508
503
509
def testShardingConstraintWithArrayOpSharding (self ):
504
510
shape = (8 , 8 )
@@ -521,11 +527,18 @@ def f(x):
521
527
self .assertLen (actual .addressable_shards , 2 )
522
528
self .assertAllClose (actual , expected , check_dtypes = False )
523
529
524
- hlo = f .lower (np .ones (shape )).compiler_ir (dialect = "hlo" )
525
- # Annotation from with_sharding_constraint
526
- self .assertIn ('sharding={devices=[2,1]<=[2]}' , hlo .as_hlo_text ())
527
- # Annotation from pjit
528
- self .assertIn ("sharding={replicated}" , hlo .as_hlo_text ())
530
+ lowered_text = f .lower (np .ones (shape )).as_text ()
531
+ if config .use_shardy_partitioner .value :
532
+ # Annotation from with_sharding_constraint, translated from GSPMD to SDY
533
+ self .assertIn ('@mesh_0 = <["_axis_0"=2]>' , lowered_text )
534
+ self .assertIn ('<@mesh_0, [{"_axis_0"}, {}]>' , lowered_text )
535
+ # Annotation from pjit
536
+ self .assertIn ('sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>}' , lowered_text )
537
+ else :
538
+ # Annotation from with_sharding_constraint
539
+ self .assertIn ('mhlo.sharding = "{devices=[2,1]<=[2]}"' , lowered_text )
540
+ # Annotation from pjit
541
+ self .assertIn ('mhlo.sharding = "{replicated}"' , lowered_text )
529
542
530
543
def testShardingConstraintPyTreeWithArray (self ):
531
544
mesh = jtu .create_mesh ((2 , 1 ), ('x' , 'y' ))
@@ -544,13 +557,15 @@ def f(x):
544
557
self .assertLen (out [0 ].addressable_shards , 2 )
545
558
self .assertLen (out [1 ].addressable_shards , 2 )
546
559
547
- hlo = f .lower (x ).compiler_ir (dialect = "hlo" )
548
- # Annotations from with_sharding_constraint
549
- self .assertIn ('sharding={devices=[2,1]<=[2]}' , hlo .as_hlo_text ())
550
- self .assertIn ('sharding={devices=[2,1]<=[2]}' , hlo .as_hlo_text ())
560
+ lowered_text = f .lower (x ).as_text ()
561
+ if config .use_shardy_partitioner .value :
562
+ # Annotation from with_sharding_constraint
563
+ self .assertIn ('<@mesh, [{"x"}, {"y"}]>' , lowered_text )
564
+ else :
565
+ # Annotations from with_sharding_constraint
566
+ self .assertIn ('mhlo.sharding = "{devices=[2,1]<=[2]}"' , lowered_text )
551
567
552
568
def testShardingConstraintPyTreeWithUnconstrainedDimsWithJit (self ):
553
-
554
569
mesh = jtu .create_mesh ((2 , 2 ), ('x' , 'y' ))
555
570
@jax .jit
556
571
def f (x ):
@@ -571,17 +586,16 @@ def f(x):
571
586
self .assertAllClose (actual , expected , check_dtypes = False )
572
587
self .assertLen (actual [0 ]['a' ].addressable_shards , 4 )
573
588
574
- mlir_str = str ( f .lower (x ).compiler_ir () )
589
+ lowered_text = f .lower (x ).as_text ( )
575
590
if config .use_shardy_partitioner .value :
576
- self .assertIn ('<@mesh, [{?}, {"y"}, {}]>' , mlir_str )
577
- self .assertIn ('<@mesh, [{"x"}, {?}, {}]>' , mlir_str )
591
+ self .assertIn ('<@mesh, [{?}, {"y"}, {}]>' , lowered_text )
592
+ self .assertIn ('<@mesh, [{"x"}, {?}, {}]>' , lowered_text )
578
593
else :
579
- self .assertIn ("unspecified_dims=[0]" , mlir_str )
580
- self .assertIn ("unspecified_dims=[1]" , mlir_str )
594
+ self .assertIn ("unspecified_dims=[0]" , lowered_text )
595
+ self .assertIn ("unspecified_dims=[1]" , lowered_text )
581
596
582
597
@jtu .with_mesh ([('x' , 2 ), ('y' , 2 )])
583
598
def testShardingConstraintPyTreeVmapWithUnconstrainedDims (self ):
584
-
585
599
@partial (pjit , in_shardings = None , out_shardings = None )
586
600
def f (x ):
587
601
x = jax .vmap (lambda x : with_sharding_constraint (
@@ -595,13 +609,13 @@ def f(x):
595
609
v = np .arange (math .prod (shape )).reshape (shape )
596
610
x = [{'a' : v , 'b' : v * 2 }, v * 3 ]
597
611
598
- mlir_str = str ( f .lower (x ).compiler_ir () )
612
+ lowered_text = f .lower (x ).as_text ( )
599
613
if config .use_shardy_partitioner .value :
600
- self .assertIn ('<@mesh, [{?}, {?}, {"y"}]>' , mlir_str )
601
- self .assertIn ('<@mesh, [{?}, {"x"}, {?}]>' , mlir_str )
614
+ self .assertIn ('<@mesh, [{?}, {?}, {"y"}]>' , lowered_text )
615
+ self .assertIn ('<@mesh, [{?}, {"x"}, {?}]>' , lowered_text )
602
616
else :
603
- self .assertIn ("unspecified_dims=[0,1]" , mlir_str )
604
- self .assertIn ("unspecified_dims=[0,2]" , mlir_str )
617
+ self .assertIn ("unspecified_dims=[0,1]" , lowered_text )
618
+ self .assertIn ("unspecified_dims=[0,2]" , lowered_text )
605
619
606
620
def testCaching (self ):
607
621
def f (x ):
0 commit comments