Skip to content

Commit 66786a7

Browse files
authored
[MLIR][Presburger] Implement Matrix::moveColumns (#68362)
1 parent 1e51b35 commit 66786a7

File tree

3 files changed

+108
-4
lines changed

3 files changed

+108
-4
lines changed

mlir/include/mlir/Analysis/Presburger/Matrix.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,24 @@ class Matrix {
192192
/// invariants satisfied.
193193
bool hasConsistentState() const;
194194

195+
/// Move the columns in the source range [srcPos, srcPos + num) to the
196+
/// specified destination [dstPos, dstPos + num), while moving the columns
197+
/// adjacent to the source range to the left/right of the shifted columns.
198+
///
199+
/// When moving the source columns right (i.e. dstPos > srcPos), columns that
200+
/// were at positions [0, srcPos) and [dstPos + num, nCols) will stay where
201+
/// they are; columns that were at positions [srcPos, srcPos + num) will be
202+
/// moved to [dstPos, dstPos + num); and columns that were at positions
203+
/// [srcPos + num, dstPos + num) will be moved to [srcPos, dstPos).
204+
/// Equivalently, the columns [srcPos + num, dstPos + num) are interchanged
205+
/// with [srcPos, srcPos + num).
206+
/// For example, if m = |0 1 2 3 4 5| then:
207+
/// m.moveColumns(1, 3, 2) will result in m = |0 4 1 2 3 5|; or
208+
/// m.moveColumns(1, 2, 4) will result in m = |0 3 4 5 1 2|.
209+
///
210+
/// The left shift operation (i.e. dstPos < srcPos) works in a similar way.
211+
void moveColumns(unsigned srcPos, unsigned num, unsigned dstPos);
212+
195213
protected:
196214
/// The current number of rows, columns, and reserved columns. The underlying
197215
/// data vector is viewed as an nRows x nReservedColumns matrix, of which the

mlir/lib/Analysis/Presburger/Matrix.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,47 @@ void Matrix<T>::fillRow(unsigned row, const T &value) {
240240
at(row, col) = value;
241241
}
242242

243+
// moveColumns is implemented by moving the columns adjacent to the source range
244+
// to their final position. When moving right (i.e. dstPos > srcPos), the range
245+
// of the adjacent columns is [srcPos + num, dstPos + num). When moving left
246+
// (i.e. dstPos < srcPos) the range of the adjacent columns is [dstPos, srcPos).
247+
// First, zeroed out columns are inserted in the final positions of the adjacent
248+
// columns. Then, the adjacent columns are moved to their final positions by
249+
// swapping them with the zeroed columns. Finally, the now zeroed adjacent
250+
// columns are deleted.
251+
template <typename T>
252+
void Matrix<T>::moveColumns(unsigned srcPos, unsigned num, unsigned dstPos) {
253+
if (num == 0)
254+
return;
255+
256+
int offset = dstPos - srcPos;
257+
if (offset == 0)
258+
return;
259+
260+
assert(srcPos + num <= getNumColumns() &&
261+
"move source range exceeds matrix columns");
262+
assert(dstPos + num <= getNumColumns() &&
263+
"move destination range exceeds matrix columns");
264+
265+
unsigned insertCount = offset > 0 ? offset : -offset;
266+
unsigned finalAdjStart = offset > 0 ? srcPos : srcPos + num;
267+
unsigned curAdjStart = offset > 0 ? srcPos + num : dstPos;
268+
// TODO: This can be done using std::rotate.
269+
// Insert new zero columns in the positions where the adjacent columns are to
270+
// be moved.
271+
insertColumns(finalAdjStart, insertCount);
272+
// Update curAdjStart if insertion of new columns invalidates it.
273+
if (finalAdjStart < curAdjStart)
274+
curAdjStart += insertCount;
275+
276+
// Swap the adjacent columns with inserted zero columns.
277+
for (unsigned i = 0; i < insertCount; ++i)
278+
swapColumns(finalAdjStart + i, curAdjStart + i);
279+
280+
// Delete the now redundant zero columns.
281+
removeColumns(curAdjStart, insertCount);
282+
}
283+
243284
template <typename T>
244285
void Matrix<T>::addToRow(unsigned sourceRow, unsigned targetRow,
245286
const T &scale) {

mlir/unittests/Analysis/Presburger/MatrixTest.cpp

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,13 +194,21 @@ TEST(MatrixTest, resize) {
194194
EXPECT_EQ(mat(row, col), row >= 3 || col >= 3 ? 0 : int(10 * row + col));
195195
}
196196

197+
template <typename T>
198+
static void checkMatEqual(const Matrix<T> m1, const Matrix<T> m2) {
199+
EXPECT_EQ(m1.getNumRows(), m2.getNumRows());
200+
EXPECT_EQ(m1.getNumColumns(), m2.getNumColumns());
201+
202+
for (unsigned row = 0, rows = m1.getNumRows(); row < rows; ++row)
203+
for (unsigned col = 0, cols = m1.getNumColumns(); col < cols; ++col)
204+
EXPECT_EQ(m1(row, col), m2(row, col));
205+
}
206+
197207
static void checkHermiteNormalForm(const IntMatrix &mat,
198208
const IntMatrix &hermiteForm) {
199209
auto [h, u] = mat.computeHermiteNormalForm();
200210

201-
for (unsigned row = 0; row < mat.getNumRows(); row++)
202-
for (unsigned col = 0; col < mat.getNumColumns(); col++)
203-
EXPECT_EQ(h(row, col), hermiteForm(row, col));
211+
checkMatEqual(h, hermiteForm);
204212
}
205213

206214
TEST(MatrixTest, computeHermiteNormalForm) {
@@ -428,4 +436,41 @@ TEST(MatrixTest, LLL) {
428436
mat.LLL(Fraction(3, 4));
429437

430438
checkReducedBasis(mat, Fraction(3, 4));
431-
}
439+
}
440+
441+
TEST(MatrixTest, moveColumns) {
442+
IntMatrix mat =
443+
makeIntMatrix(3, 4, {{0, 1, 2, 3}, {4, 5, 6, 7}, {8, 9, 4, 2}});
444+
445+
{
446+
IntMatrix movedMat =
447+
makeIntMatrix(3, 4, {{0, 3, 1, 2}, {4, 7, 5, 6}, {8, 2, 9, 4}});
448+
449+
movedMat.moveColumns(2, 2, 1);
450+
checkMatEqual(mat, movedMat);
451+
}
452+
453+
{
454+
IntMatrix movedMat =
455+
makeIntMatrix(3, 4, {{0, 3, 1, 2}, {4, 7, 5, 6}, {8, 2, 9, 4}});
456+
457+
movedMat.moveColumns(1, 1, 3);
458+
checkMatEqual(mat, movedMat);
459+
}
460+
461+
{
462+
IntMatrix movedMat =
463+
makeIntMatrix(3, 4, {{1, 2, 0, 3}, {5, 6, 4, 7}, {9, 4, 8, 2}});
464+
465+
movedMat.moveColumns(0, 2, 1);
466+
checkMatEqual(mat, movedMat);
467+
}
468+
469+
{
470+
IntMatrix movedMat =
471+
makeIntMatrix(3, 4, {{1, 0, 2, 3}, {5, 4, 6, 7}, {9, 8, 4, 2}});
472+
473+
movedMat.moveColumns(0, 1, 1);
474+
checkMatEqual(mat, movedMat);
475+
}
476+
}

0 commit comments

Comments
 (0)