26
26
namespace tc {
27
27
namespace polyhedral {
28
28
namespace {
29
- // This returns the (inclusive) range of the mapping parameter that is active
30
- // at node under root given:
31
- // 1. a context that is the intersection of the specialization context and
32
- // the mapping context
33
- // 2. a MappingId
34
- // This range corresponds to the blocks/threads active at that particular
35
- // location in the tree.
29
+ // This returns the (inclusive) range of the mapping parameter "mappingId"
30
+ // within the context "mappingContext".
31
+ // This range corresponds to the blocks/threads active at the particular
32
+ // location in the tree where this mapping is active.
36
33
//
37
34
// This is used to tighten the kernel to only launch on the necessary amount
38
35
// of resources.
@@ -43,23 +40,20 @@ namespace {
43
40
// Otherwise, the range is asserted bounded on the left and to lie in the
44
41
// positive half of the integer axis.
45
42
std::pair<size_t , size_t > rangeOfMappingParameter (
46
- const detail::ScheduleTree* root,
47
- const detail::ScheduleTree* node,
48
- isl::set context,
43
+ isl::set mappingContext,
49
44
mapping::MappingId mappingId) {
50
- auto active =
51
- activeDomainPoints (root, node).intersect_params (context).params ();
52
- if (!active.involves_param (mappingId)) {
45
+ if (!mappingContext.involves_param (mappingId)) {
53
46
return std::make_pair (0 , std::numeric_limits<size_t >::max ());
54
47
}
55
- isl::aff a (isl::aff::param_on_domain_space (active.get_space (), mappingId));
56
- auto max = active.max_val (a);
48
+ auto space = mappingContext.get_space ();
49
+ isl::aff a (isl::aff::param_on_domain_space (space, mappingId));
50
+ auto max = mappingContext.max_val (a);
57
51
if (max.is_nan () || max.is_infty ()) {
58
52
return std::make_pair (0 , std::numeric_limits<size_t >::max ());
59
53
}
60
54
TC_CHECK (max.is_int ()) << max.to_str ();
61
55
TC_CHECK (max.is_nonneg ()) << max.to_str ();
62
- auto min = active .min_val (a);
56
+ auto min = mappingContext .min_val (a);
63
57
TC_CHECK (min.is_int ()) << max.to_str ();
64
58
TC_CHECK (min.is_nonneg ()) << max.to_str ();
65
59
@@ -68,13 +62,52 @@ std::pair<size_t, size_t> rangeOfMappingParameter(
68
62
static_cast <size_t >(max.get_num_si ()));
69
63
}
70
64
71
- // Look for nodes with no children.
72
- inline std::vector<const detail::ScheduleTree*> leaves (
73
- const detail::ScheduleTree* tree) {
74
- return functional::Filter (
75
- [](const detail::ScheduleTree* st) { return st->numChildren () == 0 ; },
76
- detail::ScheduleTree::collect (tree));
65
+ /*
66
+ * Compute the maximal value attained by the mapping parameter "id".
67
+ */
68
+ template <typename MappingIdType>
69
+ size_t maxValue (const Scop& scop, const MappingIdType& id) {
70
+ using namespace polyhedral ::detail;
71
+
72
+ auto root = scop.scheduleRoot ();
73
+ auto params = scop.context ();
74
+ size_t sizetMax = std::numeric_limits<size_t >::max ();
75
+ size_t max = 0 ;
76
+ size_t min = sizetMax;
77
+ auto filters = root->collect (root, ScheduleTreeType::Mapping);
78
+ filters = functional::Filter (isMappingTo<MappingIdType>, filters);
79
+ for (auto p : filters) {
80
+ auto mappingNode = p->elemAs <ScheduleTreeElemMapping>();
81
+ auto active = activeDomainPoints (root, p).intersect_params (params);
82
+ active = active.intersect (mappingNode->filter_ );
83
+ auto range = rangeOfMappingParameter (active.params (), id);
84
+ min = std::min (min, range.first );
85
+ max = std::max (max, range.second );
86
+ }
87
+ // Ignore min for now but there is a future possibility for shifting
88
+ LOG_IF (WARNING, min > 0 )
89
+ << " Opportunity for tightening launch bounds with shifting -> min:"
90
+ << min;
91
+ TC_CHECK (max < sizetMax) << " missing mapping to " << id << *root;
92
+ // Inclusive range needs + 1 to translate to sizes
93
+ return max + 1 ;
94
+ }
95
+
96
+ /*
97
+ * Take grid or block launch bounds "size" and replace them
98
+ * by the tightened, actual, launch bounds used in practice.
99
+ */
100
+ template <typename MappingIdType, typename Size>
101
+ Size launchBounds (const Scop& scop, Size size) {
102
+ Size tightened;
103
+
104
+ for (size_t i = 0 ; i < size.view .size (); ++i) {
105
+ tightened.view [i] = maxValue (scop, MappingIdType::makeId (i));
106
+ }
107
+
108
+ return tightened;
77
109
}
110
+
78
111
} // namespace
79
112
80
113
// Takes grid/block launch bounds that have been passed to mapping and
@@ -84,56 +117,9 @@ std::pair<tc::Grid, tc::Block> tightenLaunchBounds(
84
117
const Scop& scop,
85
118
const tc::Grid& grid,
86
119
const tc::Block& block) {
87
- auto root = scop.scheduleRoot ();
88
- auto params = scop.context ();
89
-
90
- auto max = [root, params](const mapping::MappingId& id) -> size_t {
91
- size_t sizetMax = std::numeric_limits<size_t >::max ();
92
- size_t max = 0 ;
93
- size_t min = sizetMax;
94
- auto nonSyncLeaves = functional::Filter (
95
- [root, params](const detail::ScheduleTree* node) {
96
- auto f = node->elemAsBase <detail::ScheduleTreeElemFilter>();
97
- if (!f) {
98
- return true ;
99
- }
100
- if (f->filter_ .n_set () != 1 ) {
101
- std::stringstream ss;
102
- ss << " In tree:\n "
103
- << *root << " \n not a single set in filter: " << f->filter_ ;
104
- throw tightening::TighteningException (ss.str ());
105
- }
106
- auto single = isl::set::from_union_set (f->filter_ );
107
- auto single_id = single.get_tuple_id ();
108
- return !Scop::isSyncId (single_id) && !Scop::isWarpSyncId (single_id);
109
- },
110
- leaves (root));
111
- for (auto p : nonSyncLeaves) {
112
- auto range = rangeOfMappingParameter (root, p, params, id);
113
- min = std::min (min, range.first );
114
- max = std::max (max, range.second );
115
- }
116
- // Ignore min for now but there is a future possibility for shifting
117
- LOG_IF (WARNING, min > 0 )
118
- << " Opportunity for tightening launch bounds with shifting -> min:"
119
- << min;
120
- // Inclusive range needs + 1 to translate to sizes
121
- if (max < sizetMax) { // avoid overflow
122
- return max + 1 ;
123
- }
124
- return sizetMax;
125
- };
126
-
127
- USING_MAPPING_SHORT_NAMES (BX, BY, BZ, TX, TY, TZ);
128
- // Corner case: take the min with the current size to avoid degenerate
129
- // range in the unbounded case.
130
120
return std::make_pair (
131
- tc::Grid ({std::min (max (BX), BX.mappingSize (grid)),
132
- std::min (max (BY), BY.mappingSize (grid)),
133
- std::min (max (BZ), BZ.mappingSize (grid))}),
134
- tc::Block ({std::min (max (TX), TX.mappingSize (block)),
135
- std::min (max (TY), TY.mappingSize (block)),
136
- std::min (max (TZ), TZ.mappingSize (block))}));
121
+ launchBounds<mapping::BlockId>(scop, grid),
122
+ launchBounds<mapping::ThreadId>(scop, block));
137
123
}
138
124
} // namespace polyhedral
139
125
} // namespace tc
0 commit comments