20
20
#include " tc/core/polyhedral/cuda/mapped_scop.h"
21
21
#include " tc/core/polyhedral/cuda/mapping_types.h"
22
22
#include " tc/core/polyhedral/exceptions.h"
23
- #include " tc/core/polyhedral/functional.h"
24
- #include " tc/core/polyhedral/schedule_tree.h"
25
- #include " tc/core/polyhedral/schedule_utils.h"
26
- #include " tc/core/polyhedral/scop.h"
27
23
28
24
namespace tc {
29
25
namespace polyhedral {
30
26
namespace {
31
- // This returns the (inclusive) range of the mapping parameter "mappingId"
32
- // within the context "mappingContext".
33
- // This range corresponds to the blocks/threads active at the particular
34
- // location in the tree where this mapping is active.
35
- //
36
- // This is used to tighten the kernel to only launch on the necessary amount
37
- // of resources.
38
- //
39
- // When the range is unbounded on the right, we return the maximal positive
40
- // range (0, max_size_t). This needs to be intersected with launch bounds to
41
- // obtain the proper finite range.
42
- // Otherwise, the range is asserted bounded on the left and to lie in the
43
- // positive half of the integer axis.
44
- std::pair<size_t , size_t > rangeOfMappingParameter (
45
- isl::set mappingContext,
46
- mapping::MappingId mappingId) {
47
- if (!mappingContext.involves_param (mappingId)) {
48
- return std::make_pair (0 , std::numeric_limits<size_t >::max ());
49
- }
50
- auto space = mappingContext.get_space ();
51
- isl::aff a (isl::aff::param_on_domain_space (space, mappingId));
52
- auto max = mappingContext.max_val (a);
53
- if (max.is_nan () || max.is_infty ()) {
54
- return std::make_pair (0 , std::numeric_limits<size_t >::max ());
55
- }
56
- TC_CHECK (max.is_int ()) << max.to_str ();
57
- TC_CHECK (max.is_nonneg ()) << max.to_str ();
58
- auto min = mappingContext.min_val (a);
59
- TC_CHECK (min.is_int ()) << max.to_str ();
60
- TC_CHECK (min.is_nonneg ()) << max.to_str ();
61
-
62
- return std::make_pair (
63
- static_cast <size_t >(min.get_num_si ()),
64
- static_cast <size_t >(max.get_num_si ()));
65
- }
66
-
67
27
/*
68
- * Compute the maximal value attained by the mapping parameter "id".
28
+ * Return the mapping to MappingTypeId, i.e, either the mapping to blocks or
29
+ * the mapping to threads.
69
30
*/
70
- template <typename MappingIdType>
71
- size_t maxValue (const Scop& scop, const MappingIdType& id) {
72
- using namespace polyhedral ::detail;
73
-
74
- auto root = scop.scheduleRoot ();
75
- auto params = scop.context ();
76
- size_t sizetMax = std::numeric_limits<size_t >::max ();
77
- size_t max = 0 ;
78
- size_t min = sizetMax;
79
- auto filters = root->collect (root, ScheduleTreeType::Mapping);
80
- filters = functional::Filter (isMappingTo<MappingIdType>, filters);
81
- for (auto p : filters) {
82
- auto mappingNode = p->as <ScheduleTreeMapping>();
83
- auto active = activeDomainPoints (root, p).intersect_params (params);
84
- active = active.intersect (mappingNode->filter_ );
85
- auto range = rangeOfMappingParameter (active.params (), id);
86
- min = std::min (min, range.first );
87
- max = std::max (max, range.second );
88
- }
89
- TC_CHECK (max < sizetMax) << " missing mapping to " << id << " \n " << *root;
90
- TC_CHECK (min < sizetMax) << " missing mapping to " << id << " type\n " << *root;
91
- // Inclusive range needs + 1 to translate to sizes
92
- return max + 1 ;
31
+ template <typename MappingTypeId>
32
+ static isl::multi_union_pw_aff mappingSchedule (const MappedScop& mscop);
33
+ template <>
34
+ isl::multi_union_pw_aff mappingSchedule<mapping::BlockId>(
35
+ const MappedScop& mscop) {
36
+ return mscop.blockMappingSchedule (mscop.schedule ());
37
+ }
38
+ template <>
39
+ isl::multi_union_pw_aff mappingSchedule<mapping::ThreadId>(
40
+ const MappedScop& mscop) {
41
+ return mscop.threadMappingSchedule (mscop.schedule ());
93
42
}
94
43
95
44
/*
@@ -100,8 +49,17 @@ template <typename MappingIdType, typename Size>
100
49
Size launchBounds (const MappedScop& mscop, Size size) {
101
50
Size tightened;
102
51
52
+ auto params = mscop.scop ().context ();
53
+ auto mapping = mappingSchedule<MappingIdType>(mscop);
54
+ mapping = mapping.intersect_params (params);
55
+ auto max = mapping.max_multi_val ();
56
+
103
57
for (size_t i = 0 ; i < size.view .size (); ++i) {
104
- tightened.view [i] = maxValue (mscop.scop (), MappingIdType::makeId (i));
58
+ auto maxVal = max.get_val (i);
59
+ TC_CHECK (maxVal.is_int ()) << maxVal.to_str ();
60
+ TC_CHECK (maxVal.is_nonneg ()) << maxVal.to_str ();
61
+ // Inclusive range needs + 1 to translate to sizes
62
+ tightened.view [i] = maxVal.get_num_si () + 1 ;
105
63
}
106
64
107
65
return tightened;
0 commit comments