Skip to content

Commit cfd9093

Browse files
committed
Fix MLIR bytecode loading of resources
The bytecode reader didn't handle properly the case where resource names conflicted and were renamed, leading to orphan handles in the IR as well as overwriting the exiting resources. Differential Revision: https://reviews.llvm.org/D151408
1 parent e84589c commit cfd9093

File tree

5 files changed

+107
-1
lines changed

5 files changed

+107
-1
lines changed

mlir/lib/AsmParser/Parser.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "llvm/ADT/StringSet.h"
2626
#include "llvm/ADT/bit.h"
2727
#include "llvm/Support/Endian.h"
28+
#include "llvm/Support/MathExtras.h"
2829
#include "llvm/Support/PrettyStackTrace.h"
2930
#include "llvm/Support/SourceMgr.h"
3031
#include <algorithm>
@@ -2482,6 +2483,13 @@ class ParsedResourceEntry : public AsmParsedResourceEntry {
24822483
}
24832484
llvm::support::ulittle32_t align;
24842485
memcpy(&align, blobData->data(), sizeof(uint32_t));
2486+
if (align && !llvm::isPowerOf2_32(align)) {
2487+
return p.emitError(value.getLoc(),
2488+
"expected hex string blob for key '" + key +
2489+
"' to encode alignment in first 4 bytes, but got "
2490+
"non-power-of-2 value: " +
2491+
Twine(align));
2492+
}
24852493

24862494
// Get the data portion of the blob.
24872495
StringRef data = StringRef(*blobData).drop_front(sizeof(uint32_t));

mlir/lib/Bytecode/Reader/BytecodeReader.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "llvm/ADT/ScopeExit.h"
2525
#include "llvm/ADT/SmallString.h"
2626
#include "llvm/ADT/StringExtras.h"
27+
#include "llvm/ADT/StringRef.h"
2728
#include "llvm/Support/MemoryBufferRef.h"
2829
#include "llvm/Support/SaveAndRestore.h"
2930
#include "llvm/Support/SourceMgr.h"
@@ -516,6 +517,7 @@ class ResourceSectionReader {
516517
private:
517518
/// The table of dialect resources within the bytecode file.
518519
SmallVector<AsmDialectResourceHandle> dialectResources;
520+
llvm::StringMap<std::string> dialectResourceHandleRenamingMap;
519521
};
520522

521523
class ParsedResourceEntry : public AsmParsedResourceEntry {
@@ -604,6 +606,7 @@ parseResourceGroup(Location fileLoc, bool allowEmpty,
604606
EncodingReader &offsetReader, EncodingReader &resourceReader,
605607
StringSectionReader &stringReader, T *handler,
606608
const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef,
609+
function_ref<StringRef(StringRef)> remapKey = {},
607610
function_ref<LogicalResult(StringRef)> processKeyFn = {}) {
608611
uint64_t numResources;
609612
if (failed(offsetReader.parseVarInt(numResources)))
@@ -635,6 +638,7 @@ parseResourceGroup(Location fileLoc, bool allowEmpty,
635638

636639
// Otherwise, parse the resource value.
637640
EncodingReader entryReader(data, fileLoc);
641+
key = remapKey(key);
638642
ParsedResourceEntry entry(key, kind, entryReader, stringReader,
639643
bufferOwnerRef);
640644
if (failed(handler->parseResource(entry)))
@@ -665,8 +669,16 @@ LogicalResult ResourceSectionReader::initialize(
665669
// provides most of the arguments.
666670
auto parseGroup = [&](auto *handler, bool allowEmpty = false,
667671
function_ref<LogicalResult(StringRef)> keyFn = {}) {
672+
auto resolveKey = [&](StringRef key) -> StringRef {
673+
auto it = dialectResourceHandleRenamingMap.find(key);
674+
if (it == dialectResourceHandleRenamingMap.end())
675+
return "";
676+
return it->second;
677+
};
678+
668679
return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader,
669-
stringReader, handler, bufferOwnerRef, keyFn);
680+
stringReader, handler, bufferOwnerRef, resolveKey,
681+
keyFn);
670682
};
671683

672684
// Read the external resources from the bytecode.
@@ -714,6 +726,7 @@ LogicalResult ResourceSectionReader::initialize(
714726
<< "unknown 'resource' key '" << key << "' for dialect '"
715727
<< dialect->name << "'";
716728
}
729+
dialectResourceHandleRenamingMap[key] = handler->getResourceKey(*handle);
717730
dialectResources.push_back(*handle);
718731
return success();
719732
};
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
//===- AdaptorTest.cpp - Adaptor unit tests -------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Bytecode/BytecodeReader.h"
10+
#include "mlir/Bytecode/BytecodeWriter.h"
11+
#include "mlir/IR/AsmState.h"
12+
#include "mlir/IR/BuiltinAttributes.h"
13+
#include "mlir/IR/OpImplementation.h"
14+
#include "mlir/IR/OwningOpRef.h"
15+
#include "mlir/Parser/Parser.h"
16+
17+
#include "llvm/ADT/StringRef.h"
18+
#include "gmock/gmock.h"
19+
#include "gtest/gtest.h"
20+
21+
using namespace llvm;
22+
using namespace mlir;
23+
24+
using testing::ElementsAre;
25+
26+
StringLiteral IRWithResources = R"(
27+
module @TestDialectResources attributes {
28+
bytecode.test = dense_resource<resource> : tensor<4xi32>
29+
} {}
30+
{-#
31+
dialect_resources: {
32+
builtin: {
33+
resource: "0x1000000001000000020000000300000004000000"
34+
}
35+
}
36+
#-}
37+
)";
38+
39+
TEST(Bytecode, MultiModuleWithResource) {
40+
MLIRContext context;
41+
Builder builder(&context);
42+
ParserConfig parseConfig(&context);
43+
OwningOpRef<Operation *> module =
44+
parseSourceString<Operation *>(IRWithResources, parseConfig);
45+
ASSERT_TRUE(module);
46+
47+
// Write the module to bytecode
48+
std::string buffer;
49+
llvm::raw_string_ostream ostream(buffer);
50+
ASSERT_TRUE(succeeded(writeBytecodeToFile(module.get(), ostream)));
51+
52+
// Parse it back
53+
OwningOpRef<Operation *> roundTripModule =
54+
parseSourceString<Operation *>(ostream.str(), parseConfig);
55+
ASSERT_TRUE(roundTripModule);
56+
57+
// Try to see if we have a valid resource in the parsed module.
58+
auto checkResourceAttribute = [&](Operation *op) {
59+
Attribute attr = roundTripModule->getAttr("bytecode.test");
60+
ASSERT_TRUE(attr);
61+
auto denseResourceAttr = dyn_cast<DenseI32ResourceElementsAttr>(attr);
62+
ASSERT_TRUE(denseResourceAttr);
63+
std::optional<ArrayRef<int32_t>> attrData =
64+
denseResourceAttr.tryGetAsArrayRef();
65+
ASSERT_TRUE(attrData.has_value());
66+
ASSERT_EQ(attrData->size(), static_cast<size_t>(4));
67+
EXPECT_EQ((*attrData)[0], 1);
68+
EXPECT_EQ((*attrData)[1], 2);
69+
EXPECT_EQ((*attrData)[2], 3);
70+
EXPECT_EQ((*attrData)[3], 4);
71+
};
72+
73+
checkResourceAttribute(*module);
74+
checkResourceAttribute(*roundTripModule);
75+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
add_mlir_unittest(MLIRBytecodeTests
2+
BytecodeTest.cpp
3+
)
4+
target_link_libraries(MLIRBytecodeTests
5+
PRIVATE
6+
MLIRBytecodeReader
7+
MLIRBytecodeWriter
8+
MLIRParser
9+
)

mlir/unittests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ function(add_mlir_unittest test_dirname)
55
endfunction()
66

77
add_subdirectory(Analysis)
8+
add_subdirectory(Bytecode)
89
add_subdirectory(Conversion)
910
add_subdirectory(Debug)
1011
add_subdirectory(Dialect)

0 commit comments

Comments
 (0)