41
41
#include < Tensile/Contractions.hpp>
42
42
#include < Tensile/EmbeddedLibrary.hpp>
43
43
#include < Tensile/MasterSolutionLibrary.hpp>
44
+ #include < Tensile/PlaceholderLibrary.hpp>
44
45
#include < Tensile/Tensile.hpp>
45
46
#include < Tensile/TensorDescriptor.hpp>
46
47
#include < Tensile/Utils.hpp>
@@ -450,14 +451,94 @@ namespace
450
451
return inputs;
451
452
}
452
453
454
+ TensileLite::LazyLoadingInit getLazyLoadingArch (int deviceID)
455
+ {
456
+ hipDeviceProp_t deviceProperties;
457
+ HIP_CHECK_EXC (hipGetDeviceProperties (&deviceProperties, deviceID));
458
+ // strip out xnack/ecc from name
459
+ std::string deviceFullString (deviceProperties.gcnArchName );
460
+ std::string deviceString = deviceFullString.substr (0 , deviceFullString.find (" :" ));
461
+
462
+ if (deviceString.find (" gfx803" ) != std::string::npos)
463
+ {
464
+ return TensileLite::LazyLoadingInit::gfx803;
465
+ }
466
+ else if (deviceString.find (" gfx900" ) != std::string::npos)
467
+ {
468
+ return TensileLite::LazyLoadingInit::gfx900;
469
+ }
470
+ else if (deviceString.find (" gfx906" ) != std::string::npos)
471
+ {
472
+ return TensileLite::LazyLoadingInit::gfx906;
473
+ }
474
+ else if (deviceString.find (" gfx908" ) != std::string::npos)
475
+ {
476
+ return TensileLite::LazyLoadingInit::gfx908;
477
+ }
478
+ else if (deviceString.find (" gfx90a" ) != std::string::npos)
479
+ {
480
+ return TensileLite::LazyLoadingInit::gfx90a;
481
+ }
482
+ else if (deviceString.find (" gfx940" ) != std::string::npos)
483
+ {
484
+ return TensileLite::LazyLoadingInit::gfx940;
485
+ }
486
+ else if (deviceString.find (" gfx941" ) != std::string::npos)
487
+ {
488
+ return TensileLite::LazyLoadingInit::gfx941;
489
+ }
490
+ else if (deviceString.find (" gfx942" ) != std::string::npos)
491
+ {
492
+ return TensileLite::LazyLoadingInit::gfx942;
493
+ }
494
+ else if (deviceString.find (" gfx1010" ) != std::string::npos)
495
+ {
496
+ return TensileLite::LazyLoadingInit::gfx1010;
497
+ }
498
+ else if (deviceString.find (" gfx1011" ) != std::string::npos)
499
+ {
500
+ return TensileLite::LazyLoadingInit::gfx1011;
501
+ }
502
+ else if (deviceString.find (" gfx1012" ) != std::string::npos)
503
+ {
504
+ return TensileLite::LazyLoadingInit::gfx1012;
505
+ }
506
+ else if (deviceString.find (" gfx1030" ) != std::string::npos)
507
+ {
508
+ return TensileLite::LazyLoadingInit::gfx1030;
509
+ }
510
+ else if (deviceString.find (" gfx1100" ) != std::string::npos)
511
+ {
512
+ return TensileLite::LazyLoadingInit::gfx1100;
513
+ }
514
+ else if (deviceString.find (" gfx1101" ) != std::string::npos)
515
+ {
516
+ return TensileLite::LazyLoadingInit::gfx1101;
517
+ }
518
+ else if (deviceString.find (" gfx1102" ) != std::string::npos)
519
+ {
520
+ return TensileLite::LazyLoadingInit::gfx1102;
521
+ }
522
+ else if (deviceString.find (" gfx1200" ) != std::string::npos)
523
+ {
524
+ return TensileLite::LazyLoadingInit::gfx1200;
525
+ }
526
+ else if (deviceString.find (" gfx1201" ) != std::string::npos)
527
+ {
528
+ return TensileLite::LazyLoadingInit::gfx1201;
529
+ }
530
+ return TensileLite::LazyLoadingInit::None;
531
+ }
532
+
453
533
/* *************************************************
454
534
* The TensileHost struct interfaces with Tensile *
455
535
**************************************************/
456
536
class TensileHost
457
537
{
458
538
// The library object
459
539
std::shared_ptr<TensileLite::MasterSolutionLibrary<TensileLite::ContractionProblemGemm>> m_library;
460
- std::shared_ptr<hipDeviceProp_t> m_deviceProp;
540
+ std::unordered_set<TensileLite::LazyLoadingInit> m_deviceSet;
541
+ std::unordered_map<std::string, std::shared_ptr<hipDeviceProp_t>> m_devicePropMap;
461
542
462
543
// The adapter object. mutable is used to allow adapters to be modified
463
544
// even when they are stored in a const vector which is immutable in size
@@ -508,9 +589,9 @@ namespace
508
589
return m_library;
509
590
}
510
591
511
- auto & get_device_property () const
592
+ auto & get_device_property (const std::string& deviceName ) const
512
593
{
513
- return m_deviceProp ;
594
+ return m_devicePropMap. at (deviceName) ;
514
595
}
515
596
516
597
auto & get_adapters () const
@@ -576,7 +657,7 @@ namespace
576
657
577
658
// only load modules for the current architecture
578
659
auto dir = path + " /*" + processor + " *co" ;
579
-
660
+ # if ROCSPARSELT_TENSILE_LAZY_LOAD == 0
580
661
bool no_match = false ;
581
662
#ifdef WIN32
582
663
std::replace (dir.begin (), dir.end (), ' /' , ' \\ ' );
@@ -630,28 +711,59 @@ namespace
630
711
<< std::endl;
631
712
(void )once;
632
713
}
633
-
714
+ # endif // ROCSPARSELT_TENSILE_LAZY_LOAD == 0
634
715
// We initialize a local static variable with a lambda function call to avoid
635
716
// race conditions when multiple threads with different device IDs try to
636
717
// initialize library. This ensures that only one thread initializes library,
637
718
// and other threads trying to initialize library wait for it to complete.
638
719
static int once = [&] {
720
+ // Determine library path
721
+ std::string tensileLibPath;
722
+ #if ROCSPARSELT_TENSILE_LAZY_LOAD
723
+ #ifdef TENSILE_YAML
724
+ tensileLibPath = path + " /TensileLibrary_lazy_" + processor + " .yaml" ;
725
+ #else
726
+ tensileLibPath = path + " /TensileLibrary_lazy_" + processor + " .dat" ;
727
+ #endif
728
+ #else
639
729
#ifdef TENSILE_YAML
640
- path += " /TensileLibrary .yaml" ;
730
+ tensileLibPath = path + " /TensileLibrary_ " + processor + " .yaml" ;
641
731
#else
642
- path += " /TensileLibrary.dat" ;
732
+ tensileLibPath = path + " /TensileLibrary_" + processor + " .dat" ;
733
+ #endif
643
734
#endif
644
- if (!TestPath (path ))
735
+ if (!TestPath (tensileLibPath ))
645
736
{
646
- hipsparselt_cerr << " \n hipsparselt_error: Cannot read " << path << " : "
737
+ hipsparselt_cerr << " \n hipsparselt_error: Cannot read " << tensileLibPath << " : "
647
738
<< strerror (errno) << std::endl;
648
739
// rocsparselt_abort();
649
740
}
650
741
651
- auto lib = TensileLite::LoadLibraryFile<TensileLite::ContractionProblemGemm>(path);
742
+ // Get devices
743
+ hipDeviceProp_t prop;
744
+ int count;
745
+ HIP_CHECK_EXC (hipGetDeviceCount (&count));
746
+ for (int devId = 0 ; devId < count; devId++)
747
+ {
748
+ auto deviceArch = getLazyLoadingArch (devId);
749
+ if (m_deviceSet.find (deviceArch) == m_deviceSet.end ())
750
+ {
751
+ // populate the arch list for lazy loading
752
+ m_deviceSet.insert (deviceArch);
753
+ // populate device property map, used in finding solutions based on arch
754
+ HIP_CHECK_EXC (hipGetDeviceProperties (&prop, devId));
755
+ // strip out xnack/ecc from name
756
+ std::string deviceFullString (prop.gcnArchName );
757
+ std::string deviceString
758
+ = deviceFullString.substr (0 , deviceFullString.find (" :" ));
759
+ m_devicePropMap[deviceString] = std::make_shared<hipDeviceProp_t>(prop);
760
+ }
761
+ }
762
+
763
+ auto lib = TensileLite::LoadLibraryFile<TensileLite::ContractionProblemGemm>(tensileLibPath);
652
764
if (!lib)
653
765
{
654
- hipsparselt_cerr << " \n hipsparselt_error: Could not load " << path << std::endl;
766
+ hipsparselt_cerr << " \n hipsparselt_error: Could not load " << tensileLibPath << std::endl;
655
767
return -1 ;
656
768
}
657
769
else
@@ -662,17 +774,15 @@ namespace
662
774
return 0 ;
663
775
}();
664
776
777
+ static_cast <void >(adapter.initializeLazyLoading (processor, path));
778
+
779
+
665
780
if (!m_library && once != 0 )
666
781
{
667
782
hipsparselt_cerr << " \n hipsparselt_error: Could not initialize Tensile library"
668
783
<< std::endl;
669
784
// rocsparselt_abort();
670
785
}
671
-
672
- hipDeviceProp_t prop;
673
- THROW_IF_HIP_ERROR (hipGetDeviceProperties (&prop, deviceId));
674
-
675
- m_deviceProp = std::make_shared<hipDeviceProp_t>(prop);
676
786
}
677
787
};
678
788
@@ -719,7 +829,7 @@ namespace
719
829
if (library)
720
830
*library = host.get_library ();
721
831
if (deviceProp)
722
- *deviceProp = host.get_device_property ();
832
+ *deviceProp = host.get_device_property (rocsparselt_internal_get_arch_name () );
723
833
724
834
return *adapter;
725
835
}
@@ -919,6 +1029,11 @@ rocsparselt_status getBestSolutions(const RocsparseltContractionProblem<Ti, To,
919
1029
// auto &adapter =
920
1030
get_library_and_adapter (&library, &deviceProp, prob.handle ->device );
921
1031
1032
+ if (!library)
1033
+ {
1034
+ return rocsparselt_status_invalid_pointer;
1035
+ }
1036
+
922
1037
hardware = TensileLite::hip::GetDevice (*deviceProp);
923
1038
auto tensile_prob = ConstructTensileProblem (prob);
924
1039
// auto handle = prob.handle;
0 commit comments