diff --git a/llvm/include/llvm/Object/OffloadBundle.h b/llvm/include/llvm/Object/OffloadBundle.h index f4d5a1d878b8d..99f54ea4f28aa 100644 --- a/llvm/include/llvm/Object/OffloadBundle.h +++ b/llvm/include/llvm/Object/OffloadBundle.h @@ -32,29 +32,40 @@ namespace llvm { namespace object { +// CompressedOffloadBundle represents the format for the compressed offload +// bundles. +// +// The format is as follows: +// - Magic Number (4 bytes) - A constant "CCOB". +// - Version (2 bytes) +// - Compression Method (2 bytes) - Uses the values from +// llvm::compression::Format. +// - Total file size (4 bytes in V2, 8 bytes in V3). +// - Uncompressed Size (4 bytes in V1/V2, 8 bytes in V3). +// - Truncated MD5 Hash (8 bytes). +// - Compressed Data (variable length). class CompressedOffloadBundle { private: - static inline const size_t MagicSize = 4; - static inline const size_t VersionFieldSize = sizeof(uint16_t); - static inline const size_t MethodFieldSize = sizeof(uint16_t); - static inline const size_t FileSizeFieldSize = sizeof(uint32_t); - static inline const size_t UncompressedSizeFieldSize = sizeof(uint32_t); - static inline const size_t HashFieldSize = sizeof(uint64_t); - static inline const size_t V1HeaderSize = - MagicSize + VersionFieldSize + MethodFieldSize + - UncompressedSizeFieldSize + HashFieldSize; - static inline const size_t V2HeaderSize = - MagicSize + VersionFieldSize + FileSizeFieldSize + MethodFieldSize + - UncompressedSizeFieldSize + HashFieldSize; static inline const llvm::StringRef MagicNumber = "CCOB"; - static inline const uint16_t Version = 2; public: - LLVM_ABI static llvm::Expected> + struct CompressedBundleHeader { + unsigned Version; + llvm::compression::Format CompressionFormat; + std::optional FileSize; + size_t UncompressedFileSize; + uint64_t Hash; + + static llvm::Expected tryParse(llvm::StringRef); + }; + + static inline const uint16_t DefaultVersion = 2; + + static llvm::Expected> compress(llvm::compression::Params P, const llvm::MemoryBuffer &Input, - bool Verbose = false); - LLVM_ABI static llvm::Expected> - decompress(llvm::MemoryBufferRef &Input, bool Verbose = false); + uint16_t Version, bool Verbose = false); + static llvm::Expected> + decompress(const llvm::MemoryBuffer &Input, bool Verbose = false); }; /// Bundle entry in binary clang-offload-bundler format. @@ -62,12 +73,15 @@ struct OffloadBundleEntry { uint64_t Offset = 0u; uint64_t Size = 0u; uint64_t IDLength = 0u; - StringRef ID; - OffloadBundleEntry(uint64_t O, uint64_t S, uint64_t I, StringRef T) - : Offset(O), Size(S), IDLength(I), ID(T) {} + std::string ID; + OffloadBundleEntry(uint64_t O, uint64_t S, uint64_t I, std::string T) + : Offset(O), Size(S), IDLength(I) { + ID.reserve(T.size()); + ID = T; + } void dumpInfo(raw_ostream &OS) { OS << "Offset = " << Offset << ", Size = " << Size - << ", ID Length = " << IDLength << ", ID = " << ID; + << ", ID Length = " << IDLength << ", ID = " << ID << "\n"; } void dumpURI(raw_ostream &OS, StringRef FilePath) { OS << ID.data() << "\tfile://" << FilePath << "#offset=" << Offset @@ -82,15 +96,20 @@ class OffloadBundleFatBin { StringRef FileName; uint64_t NumberOfEntries; SmallVector Entries; + bool Decompressed; public: + std::unique_ptr DecompressedBuffer; + SmallVector getEntries() { return Entries; } uint64_t getSize() const { return Size; } StringRef getFileName() const { return FileName; } uint64_t getNumEntries() const { return NumberOfEntries; } + bool isDecompressed() const { return Decompressed; } LLVM_ABI static Expected> - create(MemoryBufferRef, uint64_t SectionOffset, StringRef FileName); + create(MemoryBufferRef, uint64_t SectionOffset, StringRef FileName, + bool Decompress = false); LLVM_ABI Error extractBundle(const ObjectFile &Source); LLVM_ABI Error dumpEntryToCodeObject(); @@ -106,9 +125,15 @@ class OffloadBundleFatBin { Entry.dumpURI(outs(), FileName); } - OffloadBundleFatBin(MemoryBufferRef Source, StringRef File) - : FileName(File), NumberOfEntries(0), - Entries(SmallVector()) {} + OffloadBundleFatBin(MemoryBufferRef Source, StringRef File, + bool Decompress = false) + : FileName(File), Decompressed(Decompress), NumberOfEntries(0), + Entries(SmallVector()) { + if (Decompress) { + DecompressedBuffer = + MemoryBuffer::getMemBufferCopy(Source.getBuffer(), File); + } + } }; enum UriTypeT { FILE_URI, MEMORY_URI }; @@ -191,6 +216,10 @@ LLVM_ABI Error extractOffloadBundleFatBinary( LLVM_ABI Error extractCodeObject(const ObjectFile &Source, int64_t Offset, int64_t Size, StringRef OutputFileName); +/// Extract code object memory from the given \p Source object file at \p Offset +/// and of \p Size, and copy into \p OutputFileName. +LLVM_ABI Error extractCodeObject(MemoryBufferRef Buffer, int64_t Offset, + int64_t Size, StringRef OutputFileName); /// Extracts an Offload Bundle Entry given by URI LLVM_ABI Error extractOffloadBundleByURI(StringRef URIstr); diff --git a/llvm/lib/Object/OffloadBundle.cpp b/llvm/lib/Object/OffloadBundle.cpp index 1e1042ce2bc21..57a8244a9b0e5 100644 --- a/llvm/lib/Object/OffloadBundle.cpp +++ b/llvm/lib/Object/OffloadBundle.cpp @@ -37,26 +37,63 @@ Error extractOffloadBundle(MemoryBufferRef Contents, uint64_t SectionOffset, size_t Offset = 0; size_t NextbundleStart = 0; + StringRef Magic; + std::unique_ptr Buffer; // There could be multiple offloading bundles stored at this section. - while (NextbundleStart != StringRef::npos) { - std::unique_ptr Buffer = + while ((NextbundleStart != StringRef::npos) && + (Offset < Contents.getBuffer().size())) { + Buffer = MemoryBuffer::getMemBuffer(Contents.getBuffer().drop_front(Offset), "", /*RequiresNullTerminator=*/false); - // Create the FatBinBindle object. This will also create the Bundle Entry - // list info. - auto FatBundleOrErr = - OffloadBundleFatBin::create(*Buffer, SectionOffset + Offset, FileName); - if (!FatBundleOrErr) - return FatBundleOrErr.takeError(); - - // Add current Bundle to list. - Bundles.emplace_back(std::move(**FatBundleOrErr)); + if (identify_magic((*Buffer).getBuffer()) == + file_magic::offload_bundle_compressed) { + Magic = StringRef("CCOB"); + // decompress this bundle first. + NextbundleStart = (*Buffer).getBuffer().find(Magic, Magic.size()); + if (NextbundleStart == StringRef::npos) { + NextbundleStart = (*Buffer).getBuffer().size(); + } - // Find the next bundle by searching for the magic string - StringRef Str = Buffer->getBuffer(); - NextbundleStart = Str.find(StringRef("__CLANG_OFFLOAD_BUNDLE__"), 24); + ErrorOr> CodeOrErr = + MemoryBuffer::getMemBuffer((*Buffer).getBuffer().take_front( + NextbundleStart /*- Magic.size()*/), + FileName, false); + if (std::error_code EC = CodeOrErr.getError()) + return createFileError(FileName, EC); + + Expected> DecompressedBufferOrErr = + CompressedOffloadBundle::decompress(**CodeOrErr, false); + if (!DecompressedBufferOrErr) + return createStringError( + inconvertibleErrorCode(), + "Failed to decompress input: " + + llvm::toString(DecompressedBufferOrErr.takeError())); + + auto FatBundleOrErr = OffloadBundleFatBin::create( + **DecompressedBufferOrErr, Offset, FileName, true); + if (!FatBundleOrErr) + return FatBundleOrErr.takeError(); + + // Add current Bundle to list. + Bundles.emplace_back(std::move(**FatBundleOrErr)); + + } else if (identify_magic((*Buffer).getBuffer()) == + file_magic::offload_bundle) { + // Create the FatBinBindle object. This will also create the Bundle Entry + // list info. + auto FatBundleOrErr = OffloadBundleFatBin::create( + *Buffer, SectionOffset + Offset, FileName); + if (!FatBundleOrErr) + return FatBundleOrErr.takeError(); + + // Add current Bundle to list. + Bundles.emplace_back(std::move(**FatBundleOrErr)); + + Magic = StringRef("__CLANG_OFFLOAD_BUNDLE__"); + NextbundleStart = (*Buffer).getBuffer().find(Magic, Magic.size()); + } if (NextbundleStart != StringRef::npos) Offset += NextbundleStart; @@ -102,7 +139,8 @@ Error OffloadBundleFatBin::readEntries(StringRef Buffer, return errorCodeToError(object_error::parse_failed); auto Entry = std::make_unique( - EntryOffset + SectionOffset, EntrySize, EntryIDSize, EntryID); + EntryOffset + SectionOffset, EntrySize, EntryIDSize, + std::move(EntryID.str())); Entries.push_back(*Entry); } @@ -112,18 +150,22 @@ Error OffloadBundleFatBin::readEntries(StringRef Buffer, Expected> OffloadBundleFatBin::create(MemoryBufferRef Buf, uint64_t SectionOffset, - StringRef FileName) { + StringRef FileName, bool Decompress) { if (Buf.getBufferSize() < 24) return errorCodeToError(object_error::parse_failed); // Check for magic bytes. - if (identify_magic(Buf.getBuffer()) != file_magic::offload_bundle) + if ((identify_magic(Buf.getBuffer()) != file_magic::offload_bundle) && + (identify_magic(Buf.getBuffer()) != + file_magic::offload_bundle_compressed)) return errorCodeToError(object_error::parse_failed); - OffloadBundleFatBin *TheBundle = new OffloadBundleFatBin(Buf, FileName); + OffloadBundleFatBin *TheBundle = + new OffloadBundleFatBin(Buf, FileName, Decompress); // Read the Bundle Entries - Error Err = TheBundle->readEntries(Buf.getBuffer(), SectionOffset); + Error Err = + TheBundle->readEntries(Buf.getBuffer(), Decompress ? 0 : SectionOffset); if (Err) return errorCodeToError(object_error::parse_failed); @@ -172,28 +214,9 @@ Error object::extractOffloadBundleFatBinary( "COFF object files not supported.\n"); MemoryBufferRef Contents(*Buffer, Obj.getFileName()); - - if (llvm::identify_magic(*Buffer) == - llvm::file_magic::offload_bundle_compressed) { - // Decompress the input if necessary. - Expected> DecompressedBufferOrErr = - CompressedOffloadBundle::decompress(Contents, false); - - if (!DecompressedBufferOrErr) - return createStringError( - inconvertibleErrorCode(), - "Failed to decompress input: " + - llvm::toString(DecompressedBufferOrErr.takeError())); - - MemoryBuffer &DecompressedInput = **DecompressedBufferOrErr; - if (Error Err = extractOffloadBundle(DecompressedInput, SectionOffset, - Obj.getFileName(), Bundles)) - return Err; - } else { - if (Error Err = extractOffloadBundle(Contents, SectionOffset, - Obj.getFileName(), Bundles)) - return Err; - } + if (Error Err = extractOffloadBundle(Contents, SectionOffset, + Obj.getFileName(), Bundles)) + return Err; } } return Error::success(); @@ -221,6 +244,22 @@ Error object::extractCodeObject(const ObjectFile &Source, int64_t Offset, return Error::success(); } +Error object::extractCodeObject(const MemoryBufferRef Buffer, int64_t Offset, + int64_t Size, StringRef OutputFileName) { + Expected> BufferOrErr = + FileOutputBuffer::create(OutputFileName, Size); + if (!BufferOrErr) + return BufferOrErr.takeError(); + + std::unique_ptr Buf = std::move(*BufferOrErr); + std::copy(Buffer.getBufferStart() + Offset, + Buffer.getBufferStart() + Offset + Size, Buf->getBufferStart()); + if (Error E = Buf->commit()) + return E; + + return Error::success(); +} + // given a file name, offset, and size, extract data into a code object file, // into file -offset-size.co Error object::extractOffloadBundleByURI(StringRef URIstr) { @@ -260,11 +299,233 @@ static std::string formatWithCommas(unsigned long long Value) { } llvm::Expected> -CompressedOffloadBundle::decompress(llvm::MemoryBufferRef &Input, +CompressedOffloadBundle::compress(llvm::compression::Params P, + const llvm::MemoryBuffer &Input, + uint16_t Version, bool Verbose) { + if (!llvm::compression::zstd::isAvailable() && + !llvm::compression::zlib::isAvailable()) + return createStringError(llvm::inconvertibleErrorCode(), + "Compression not supported"); + llvm::Timer HashTimer("Hash Calculation Timer", "Hash calculation time", + OffloadBundlerTimerGroup); + if (Verbose) + HashTimer.startTimer(); + llvm::MD5 Hash; + llvm::MD5::MD5Result Result; + Hash.update(Input.getBuffer()); + Hash.final(Result); + uint64_t TruncatedHash = Result.low(); + if (Verbose) + HashTimer.stopTimer(); + + SmallVector CompressedBuffer; + auto BufferUint8 = llvm::ArrayRef( + reinterpret_cast(Input.getBuffer().data()), + Input.getBuffer().size()); + llvm::Timer CompressTimer("Compression Timer", "Compression time", + OffloadBundlerTimerGroup); + if (Verbose) + CompressTimer.startTimer(); + llvm::compression::compress(P, BufferUint8, CompressedBuffer); + if (Verbose) + CompressTimer.stopTimer(); + + uint16_t CompressionMethod = static_cast(P.format); + + // Store sizes in 64-bit variables first + uint64_t UncompressedSize64 = Input.getBuffer().size(); + uint64_t TotalFileSize64; + + // Calculate total file size based on version + if (Version == 2) { + // For V2, ensure the sizes don't exceed 32-bit limit + if (UncompressedSize64 > std::numeric_limits::max()) + return createStringError(llvm::inconvertibleErrorCode(), + "Uncompressed size exceeds version 2 limit"); + if ((MagicNumber.size() + sizeof(uint32_t) + sizeof(Version) + + sizeof(CompressionMethod) + sizeof(uint32_t) + sizeof(TruncatedHash) + + CompressedBuffer.size()) > std::numeric_limits::max()) + return createStringError(llvm::inconvertibleErrorCode(), + "Total file size exceeds version 2 limit"); + + TotalFileSize64 = MagicNumber.size() + sizeof(uint32_t) + sizeof(Version) + + sizeof(CompressionMethod) + sizeof(uint32_t) + + sizeof(TruncatedHash) + CompressedBuffer.size(); + } else { // Version 3 + TotalFileSize64 = MagicNumber.size() + sizeof(uint64_t) + sizeof(Version) + + sizeof(CompressionMethod) + sizeof(uint64_t) + + sizeof(TruncatedHash) + CompressedBuffer.size(); + } + + SmallVector FinalBuffer; + llvm::raw_svector_ostream OS(FinalBuffer); + OS << MagicNumber; + OS.write(reinterpret_cast(&Version), sizeof(Version)); + OS.write(reinterpret_cast(&CompressionMethod), + sizeof(CompressionMethod)); + + // Write size fields according to version + if (Version == 2) { + uint32_t TotalFileSize32 = static_cast(TotalFileSize64); + uint32_t UncompressedSize32 = static_cast(UncompressedSize64); + OS.write(reinterpret_cast(&TotalFileSize32), + sizeof(TotalFileSize32)); + OS.write(reinterpret_cast(&UncompressedSize32), + sizeof(UncompressedSize32)); + } else { // Version 3 + OS.write(reinterpret_cast(&TotalFileSize64), + sizeof(TotalFileSize64)); + OS.write(reinterpret_cast(&UncompressedSize64), + sizeof(UncompressedSize64)); + } + + OS.write(reinterpret_cast(&TruncatedHash), + sizeof(TruncatedHash)); + OS.write(reinterpret_cast(CompressedBuffer.data()), + CompressedBuffer.size()); + + if (Verbose) { + auto MethodUsed = + P.format == llvm::compression::Format::Zstd ? "zstd" : "zlib"; + double CompressionRate = + static_cast(UncompressedSize64) / CompressedBuffer.size(); + double CompressionTimeSeconds = CompressTimer.getTotalTime().getWallTime(); + double CompressionSpeedMBs = + (UncompressedSize64 / (1024.0 * 1024.0)) / CompressionTimeSeconds; + llvm::errs() << "Compressed bundle format version: " << Version << "\n" + << "Total file size (including headers): " + << formatWithCommas(TotalFileSize64) << " bytes\n" + << "Compression method used: " << MethodUsed << "\n" + << "Compression level: " << P.level << "\n" + << "Binary size before compression: " + << formatWithCommas(UncompressedSize64) << " bytes\n" + << "Binary size after compression: " + << formatWithCommas(CompressedBuffer.size()) << " bytes\n" + << "Compression rate: " + << llvm::format("%.2lf", CompressionRate) << "\n" + << "Compression ratio: " + << llvm::format("%.2lf%%", 100.0 / CompressionRate) << "\n" + << "Compression speed: " + << llvm::format("%.2lf MB/s", CompressionSpeedMBs) << "\n" + << "Truncated MD5 hash: " + << llvm::format_hex(TruncatedHash, 16) << "\n"; + } + + return llvm::MemoryBuffer::getMemBufferCopy( + llvm::StringRef(FinalBuffer.data(), FinalBuffer.size())); +} + +// Use packed structs to avoid padding, such that the structs map the serialized +// format. +LLVM_PACKED_START +union RawCompressedBundleHeader { + struct CommonFields { + uint32_t Magic; + uint16_t Version; + uint16_t Method; + }; + + struct V1Header { + CommonFields Common; + uint32_t UncompressedFileSize; + uint64_t Hash; + }; + + struct V2Header { + CommonFields Common; + uint32_t FileSize; + uint32_t UncompressedFileSize; + uint64_t Hash; + }; + + struct V3Header { + CommonFields Common; + uint64_t FileSize; + uint64_t UncompressedFileSize; + uint64_t Hash; + }; + + CommonFields Common; + V1Header V1; + V2Header V2; + V3Header V3; +}; +LLVM_PACKED_END + +// Helper method to get header size based on version +static size_t getHeaderSize(uint16_t Version) { + switch (Version) { + case 1: + return sizeof(RawCompressedBundleHeader::V1Header); + case 2: + return sizeof(RawCompressedBundleHeader::V2Header); + case 3: + return sizeof(RawCompressedBundleHeader::V3Header); + default: + llvm_unreachable("Unsupported version"); + } +} + +Expected +CompressedOffloadBundle::CompressedBundleHeader::tryParse(StringRef Blob) { + assert(Blob.size() >= sizeof(RawCompressedBundleHeader::CommonFields)); + assert(llvm::identify_magic(Blob) == + llvm::file_magic::offload_bundle_compressed); + + RawCompressedBundleHeader Header; + memcpy(&Header, Blob.data(), std::min(Blob.size(), sizeof(Header))); + + CompressedBundleHeader Normalized; + Normalized.Version = Header.Common.Version; + + size_t RequiredSize = getHeaderSize(Normalized.Version); + + if (Blob.size() < RequiredSize) + return createStringError(inconvertibleErrorCode(), + "Compressed bundle header size too small"); + + switch (Normalized.Version) { + case 1: + Normalized.UncompressedFileSize = Header.V1.UncompressedFileSize; + Normalized.Hash = Header.V1.Hash; + break; + case 2: + Normalized.FileSize = Header.V2.FileSize; + Normalized.UncompressedFileSize = Header.V2.UncompressedFileSize; + Normalized.Hash = Header.V2.Hash; + break; + case 3: + Normalized.FileSize = Header.V3.FileSize; + Normalized.UncompressedFileSize = Header.V3.UncompressedFileSize; + Normalized.Hash = Header.V3.Hash; + break; + default: + return createStringError(inconvertibleErrorCode(), + "Unknown compressed bundle version"); + } + + // Determine compression format + switch (Header.Common.Method) { + case static_cast(compression::Format::Zlib): + case static_cast(compression::Format::Zstd): + Normalized.CompressionFormat = + static_cast(Header.Common.Method); + break; + default: + return createStringError(inconvertibleErrorCode(), + "Unknown compressing method"); + } + + return Normalized; +} + +llvm::Expected> +CompressedOffloadBundle::decompress(const llvm::MemoryBuffer &Input, bool Verbose) { StringRef Blob = Input.getBuffer(); - if (Blob.size() < V1HeaderSize) + // Check minimum header size (using V1 as it's the smallest) + if (Blob.size() < sizeof(RawCompressedBundleHeader::CommonFields)) return llvm::MemoryBuffer::getMemBufferCopy(Blob); if (llvm::identify_magic(Blob) != @@ -274,43 +535,20 @@ CompressedOffloadBundle::decompress(llvm::MemoryBufferRef &Input, return llvm::MemoryBuffer::getMemBufferCopy(Blob); } - size_t CurrentOffset = MagicSize; - - uint16_t ThisVersion; - memcpy(&ThisVersion, Blob.data() + CurrentOffset, sizeof(uint16_t)); - CurrentOffset += VersionFieldSize; + Expected HeaderOrErr = + CompressedBundleHeader::tryParse(Blob); + if (!HeaderOrErr) + return HeaderOrErr.takeError(); - uint16_t CompressionMethod; - memcpy(&CompressionMethod, Blob.data() + CurrentOffset, sizeof(uint16_t)); - CurrentOffset += MethodFieldSize; + const CompressedBundleHeader &Normalized = *HeaderOrErr; + unsigned ThisVersion = Normalized.Version; + size_t HeaderSize = getHeaderSize(ThisVersion); - uint32_t TotalFileSize; - if (ThisVersion >= 2) { - if (Blob.size() < V2HeaderSize) - return createStringError(inconvertibleErrorCode(), - "Compressed bundle header size too small"); - memcpy(&TotalFileSize, Blob.data() + CurrentOffset, sizeof(uint32_t)); - CurrentOffset += FileSizeFieldSize; - } + llvm::compression::Format CompressionFormat = Normalized.CompressionFormat; - uint32_t UncompressedSize; - memcpy(&UncompressedSize, Blob.data() + CurrentOffset, sizeof(uint32_t)); - CurrentOffset += UncompressedSizeFieldSize; - - uint64_t StoredHash; - memcpy(&StoredHash, Blob.data() + CurrentOffset, sizeof(uint64_t)); - CurrentOffset += HashFieldSize; - - llvm::compression::Format CompressionFormat; - if (CompressionMethod == - static_cast(llvm::compression::Format::Zlib)) - CompressionFormat = llvm::compression::Format::Zlib; - else if (CompressionMethod == - static_cast(llvm::compression::Format::Zstd)) - CompressionFormat = llvm::compression::Format::Zstd; - else - return createStringError(inconvertibleErrorCode(), - "Unknown compressing method"); + size_t TotalFileSize = Normalized.FileSize.value_or(0); + size_t UncompressedSize = Normalized.UncompressedFileSize; + auto StoredHash = Normalized.Hash; llvm::Timer DecompressTimer("Decompression Timer", "Decompression time", OffloadBundlerTimerGroup); @@ -318,7 +556,9 @@ CompressedOffloadBundle::decompress(llvm::MemoryBufferRef &Input, DecompressTimer.startTimer(); SmallVector DecompressedData; - StringRef CompressedData = Blob.substr(CurrentOffset); + StringRef CompressedData = + Blob.substr(HeaderSize, TotalFileSize - HeaderSize); + if (llvm::Error DecompressionError = llvm::compression::decompress( CompressionFormat, llvm::arrayRefFromStringRef(CompressedData), DecompressedData, UncompressedSize)) @@ -332,7 +572,7 @@ CompressedOffloadBundle::decompress(llvm::MemoryBufferRef &Input, double DecompressionTimeSeconds = DecompressTimer.getTotalTime().getWallTime(); - // Recalculate MD5 hash for integrity check. + // Recalculate MD5 hash for integrity check llvm::Timer HashRecalcTimer("Hash Recalculation Timer", "Hash recalculation time", OffloadBundlerTimerGroup); @@ -378,90 +618,3 @@ CompressedOffloadBundle::decompress(llvm::MemoryBufferRef &Input, return llvm::MemoryBuffer::getMemBufferCopy( llvm::toStringRef(DecompressedData)); } - -llvm::Expected> -CompressedOffloadBundle::compress(llvm::compression::Params P, - const llvm::MemoryBuffer &Input, - bool Verbose) { - if (!llvm::compression::zstd::isAvailable() && - !llvm::compression::zlib::isAvailable()) - return createStringError(llvm::inconvertibleErrorCode(), - "Compression not supported"); - - llvm::Timer HashTimer("Hash Calculation Timer", "Hash calculation time", - OffloadBundlerTimerGroup); - if (Verbose) - HashTimer.startTimer(); - llvm::MD5 Hash; - llvm::MD5::MD5Result Result; - Hash.update(Input.getBuffer()); - Hash.final(Result); - uint64_t TruncatedHash = Result.low(); - if (Verbose) - HashTimer.stopTimer(); - - SmallVector CompressedBuffer; - auto BufferUint8 = llvm::ArrayRef( - reinterpret_cast(Input.getBuffer().data()), - Input.getBuffer().size()); - - llvm::Timer CompressTimer("Compression Timer", "Compression time", - OffloadBundlerTimerGroup); - if (Verbose) - CompressTimer.startTimer(); - llvm::compression::compress(P, BufferUint8, CompressedBuffer); - if (Verbose) - CompressTimer.stopTimer(); - - uint16_t CompressionMethod = static_cast(P.format); - uint32_t UncompressedSize = Input.getBuffer().size(); - uint32_t TotalFileSize = MagicNumber.size() + sizeof(TotalFileSize) + - sizeof(Version) + sizeof(CompressionMethod) + - sizeof(UncompressedSize) + sizeof(TruncatedHash) + - CompressedBuffer.size(); - - SmallVector FinalBuffer; - llvm::raw_svector_ostream OS(FinalBuffer); - OS << MagicNumber; - OS.write(reinterpret_cast(&Version), sizeof(Version)); - OS.write(reinterpret_cast(&CompressionMethod), - sizeof(CompressionMethod)); - OS.write(reinterpret_cast(&TotalFileSize), - sizeof(TotalFileSize)); - OS.write(reinterpret_cast(&UncompressedSize), - sizeof(UncompressedSize)); - OS.write(reinterpret_cast(&TruncatedHash), - sizeof(TruncatedHash)); - OS.write(reinterpret_cast(CompressedBuffer.data()), - CompressedBuffer.size()); - - if (Verbose) { - auto MethodUsed = - P.format == llvm::compression::Format::Zstd ? "zstd" : "zlib"; - double CompressionRate = - static_cast(UncompressedSize) / CompressedBuffer.size(); - double CompressionTimeSeconds = CompressTimer.getTotalTime().getWallTime(); - double CompressionSpeedMBs = - (UncompressedSize / (1024.0 * 1024.0)) / CompressionTimeSeconds; - - llvm::errs() << "Compressed bundle format version: " << Version << "\n" - << "Total file size (including headers): " - << formatWithCommas(TotalFileSize) << " bytes\n" - << "Compression method used: " << MethodUsed << "\n" - << "Compression level: " << P.level << "\n" - << "Binary size before compression: " - << formatWithCommas(UncompressedSize) << " bytes\n" - << "Binary size after compression: " - << formatWithCommas(CompressedBuffer.size()) << " bytes\n" - << "Compression rate: " - << llvm::format("%.2lf", CompressionRate) << "\n" - << "Compression ratio: " - << llvm::format("%.2lf%%", 100.0 / CompressionRate) << "\n" - << "Compression speed: " - << llvm::format("%.2lf MB/s", CompressionSpeedMBs) << "\n" - << "Truncated MD5 hash: " - << llvm::format_hex(TruncatedHash, 16) << "\n"; - } - return llvm::MemoryBuffer::getMemBufferCopy( - llvm::StringRef(FinalBuffer.data(), FinalBuffer.size())); -} diff --git a/llvm/tools/llvm-objdump/OffloadDump.cpp b/llvm/tools/llvm-objdump/OffloadDump.cpp index 8a0deb35ba151..8a34d0f60a2ac 100644 --- a/llvm/tools/llvm-objdump/OffloadDump.cpp +++ b/llvm/tools/llvm-objdump/OffloadDump.cpp @@ -87,21 +87,30 @@ void llvm::dumpOffloadBundleFatBinary(const ObjectFile &O, StringRef ArchName) { if (Error Err = llvm::object::extractOffloadBundleFatBinary(O, FoundBundles)) reportError(O.getFileName(), "while extracting offload FatBin bundles: " + toString(std::move(Err))); - for (const auto &[BundleNum, Bundle] : llvm::enumerate(FoundBundles)) { for (OffloadBundleEntry &Entry : Bundle.getEntries()) { - if (!ArchName.empty() && !Entry.ID.contains(ArchName)) + if (!ArchName.empty() && (Entry.ID.find(ArchName) != std::string::npos)) continue; // create file name for this object file: .. - std::string str = Bundle.getFileName().str() + "." + itostr(BundleNum) + - "." + Entry.ID.str(); - if (Error Err = object::extractCodeObject(O, Entry.Offset, Entry.Size, - StringRef(str))) - reportError(O.getFileName(), - "while extracting offload Bundle Entries: " + - toString(std::move(Err))); + std::string str = + Bundle.getFileName().str() + "." + itostr(BundleNum) + "." + Entry.ID; + + if (Bundle.isDecompressed()) { + if (Error Err = object::extractCodeObject( + Bundle.DecompressedBuffer->getMemBufferRef(), Entry.Offset, + Entry.Size, StringRef(str))) + reportError(O.getFileName(), + "while extracting offload Bundle Entries: " + + toString(std::move(Err))); + } else { + if (Error Err = object::extractCodeObject(O, Entry.Offset, Entry.Size, + StringRef(str))) + reportError(O.getFileName(), + "while extracting offload Bundle Entries: " + + toString(std::move(Err))); + } outs() << "Extracting offload bundle: " << str << "\n"; } }