Skip to content

fix compress/decompress in LLVM Offloading API #147616

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

david-salinas
Copy link
Contributor

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Jul 8, 2025

@llvm/pr-subscribers-llvm-binary-utilities

Author: David Salinas (david-salinas)

Changes

Patch is 31.04 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/147616.diff

3 Files Affected:

  • (modified) llvm/include/llvm/Object/OffloadBundle.h (+54-25)
  • (modified) llvm/lib/Object/OffloadBundle.cpp (+319-166)
  • (modified) llvm/tools/llvm-objdump/OffloadDump.cpp (+18-9)
diff --git a/llvm/include/llvm/Object/OffloadBundle.h b/llvm/include/llvm/Object/OffloadBundle.h
index f4d5a1d878b8d..99f54ea4f28aa 100644
--- a/llvm/include/llvm/Object/OffloadBundle.h
+++ b/llvm/include/llvm/Object/OffloadBundle.h
@@ -32,29 +32,40 @@ namespace llvm {
 
 namespace object {
 
+// CompressedOffloadBundle represents the format for the compressed offload
+// bundles.
+//
+// The format is as follows:
+// - Magic Number (4 bytes) - A constant "CCOB".
+// - Version (2 bytes)
+// - Compression Method (2 bytes) - Uses the values from
+// llvm::compression::Format.
+// - Total file size (4 bytes in V2, 8 bytes in V3).
+// - Uncompressed Size (4 bytes in V1/V2, 8 bytes in V3).
+// - Truncated MD5 Hash (8 bytes).
+// - Compressed Data (variable length).
 class CompressedOffloadBundle {
 private:
-  static inline const size_t MagicSize = 4;
-  static inline const size_t VersionFieldSize = sizeof(uint16_t);
-  static inline const size_t MethodFieldSize = sizeof(uint16_t);
-  static inline const size_t FileSizeFieldSize = sizeof(uint32_t);
-  static inline const size_t UncompressedSizeFieldSize = sizeof(uint32_t);
-  static inline const size_t HashFieldSize = sizeof(uint64_t);
-  static inline const size_t V1HeaderSize =
-      MagicSize + VersionFieldSize + MethodFieldSize +
-      UncompressedSizeFieldSize + HashFieldSize;
-  static inline const size_t V2HeaderSize =
-      MagicSize + VersionFieldSize + FileSizeFieldSize + MethodFieldSize +
-      UncompressedSizeFieldSize + HashFieldSize;
   static inline const llvm::StringRef MagicNumber = "CCOB";
-  static inline const uint16_t Version = 2;
 
 public:
-  LLVM_ABI static llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
+  struct CompressedBundleHeader {
+    unsigned Version;
+    llvm::compression::Format CompressionFormat;
+    std::optional<size_t> FileSize;
+    size_t UncompressedFileSize;
+    uint64_t Hash;
+
+    static llvm::Expected<CompressedBundleHeader> tryParse(llvm::StringRef);
+  };
+
+  static inline const uint16_t DefaultVersion = 2;
+
+  static llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
   compress(llvm::compression::Params P, const llvm::MemoryBuffer &Input,
-           bool Verbose = false);
-  LLVM_ABI static llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
-  decompress(llvm::MemoryBufferRef &Input, bool Verbose = false);
+           uint16_t Version, bool Verbose = false);
+  static llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
+  decompress(const llvm::MemoryBuffer &Input, bool Verbose = false);
 };
 
 /// Bundle entry in binary clang-offload-bundler format.
@@ -62,12 +73,15 @@ struct OffloadBundleEntry {
   uint64_t Offset = 0u;
   uint64_t Size = 0u;
   uint64_t IDLength = 0u;
-  StringRef ID;
-  OffloadBundleEntry(uint64_t O, uint64_t S, uint64_t I, StringRef T)
-      : Offset(O), Size(S), IDLength(I), ID(T) {}
+  std::string ID;
+  OffloadBundleEntry(uint64_t O, uint64_t S, uint64_t I, std::string T)
+      : Offset(O), Size(S), IDLength(I) {
+    ID.reserve(T.size());
+    ID = T;
+  }
   void dumpInfo(raw_ostream &OS) {
     OS << "Offset = " << Offset << ", Size = " << Size
-       << ", ID Length = " << IDLength << ", ID = " << ID;
+       << ", ID Length = " << IDLength << ", ID = " << ID << "\n";
   }
   void dumpURI(raw_ostream &OS, StringRef FilePath) {
     OS << ID.data() << "\tfile://" << FilePath << "#offset=" << Offset
@@ -82,15 +96,20 @@ class OffloadBundleFatBin {
   StringRef FileName;
   uint64_t NumberOfEntries;
   SmallVector<OffloadBundleEntry> Entries;
+  bool Decompressed;
 
 public:
+  std::unique_ptr<MemoryBuffer> DecompressedBuffer;
+
   SmallVector<OffloadBundleEntry> getEntries() { return Entries; }
   uint64_t getSize() const { return Size; }
   StringRef getFileName() const { return FileName; }
   uint64_t getNumEntries() const { return NumberOfEntries; }
+  bool isDecompressed() const { return Decompressed; }
 
   LLVM_ABI static Expected<std::unique_ptr<OffloadBundleFatBin>>
-  create(MemoryBufferRef, uint64_t SectionOffset, StringRef FileName);
+  create(MemoryBufferRef, uint64_t SectionOffset, StringRef FileName,
+         bool Decompress = false);
   LLVM_ABI Error extractBundle(const ObjectFile &Source);
 
   LLVM_ABI Error dumpEntryToCodeObject();
@@ -106,9 +125,15 @@ class OffloadBundleFatBin {
       Entry.dumpURI(outs(), FileName);
   }
 
-  OffloadBundleFatBin(MemoryBufferRef Source, StringRef File)
-      : FileName(File), NumberOfEntries(0),
-        Entries(SmallVector<OffloadBundleEntry>()) {}
+  OffloadBundleFatBin(MemoryBufferRef Source, StringRef File,
+                      bool Decompress = false)
+      : FileName(File), Decompressed(Decompress), NumberOfEntries(0),
+        Entries(SmallVector<OffloadBundleEntry>()) {
+    if (Decompress) {
+      DecompressedBuffer =
+          MemoryBuffer::getMemBufferCopy(Source.getBuffer(), File);
+    }
+  }
 };
 
 enum UriTypeT { FILE_URI, MEMORY_URI };
@@ -191,6 +216,10 @@ LLVM_ABI Error extractOffloadBundleFatBinary(
 LLVM_ABI Error extractCodeObject(const ObjectFile &Source, int64_t Offset,
                                  int64_t Size, StringRef OutputFileName);
 
+/// Extract code object memory from the given \p Source object file at \p Offset
+/// and of \p Size, and copy into \p OutputFileName.
+LLVM_ABI Error extractCodeObject(MemoryBufferRef Buffer, int64_t Offset,
+                                 int64_t Size, StringRef OutputFileName);
 /// Extracts an Offload Bundle Entry given by URI
 LLVM_ABI Error extractOffloadBundleByURI(StringRef URIstr);
 
diff --git a/llvm/lib/Object/OffloadBundle.cpp b/llvm/lib/Object/OffloadBundle.cpp
index 1e1042ce2bc21..57a8244a9b0e5 100644
--- a/llvm/lib/Object/OffloadBundle.cpp
+++ b/llvm/lib/Object/OffloadBundle.cpp
@@ -37,26 +37,63 @@ Error extractOffloadBundle(MemoryBufferRef Contents, uint64_t SectionOffset,
 
   size_t Offset = 0;
   size_t NextbundleStart = 0;
+  StringRef Magic;
+  std::unique_ptr<MemoryBuffer> Buffer;
 
   // There could be multiple offloading bundles stored at this section.
-  while (NextbundleStart != StringRef::npos) {
-    std::unique_ptr<MemoryBuffer> Buffer =
+  while ((NextbundleStart != StringRef::npos) &&
+         (Offset < Contents.getBuffer().size())) {
+    Buffer =
         MemoryBuffer::getMemBuffer(Contents.getBuffer().drop_front(Offset), "",
                                    /*RequiresNullTerminator=*/false);
 
-    // Create the FatBinBindle object. This will also create the Bundle Entry
-    // list info.
-    auto FatBundleOrErr =
-        OffloadBundleFatBin::create(*Buffer, SectionOffset + Offset, FileName);
-    if (!FatBundleOrErr)
-      return FatBundleOrErr.takeError();
-
-    // Add current Bundle to list.
-    Bundles.emplace_back(std::move(**FatBundleOrErr));
+    if (identify_magic((*Buffer).getBuffer()) ==
+        file_magic::offload_bundle_compressed) {
+      Magic = StringRef("CCOB");
+      // decompress this bundle first.
+      NextbundleStart = (*Buffer).getBuffer().find(Magic, Magic.size());
+      if (NextbundleStart == StringRef::npos) {
+        NextbundleStart = (*Buffer).getBuffer().size();
+      }
 
-    // Find the next bundle by searching for the magic string
-    StringRef Str = Buffer->getBuffer();
-    NextbundleStart = Str.find(StringRef("__CLANG_OFFLOAD_BUNDLE__"), 24);
+      ErrorOr<std::unique_ptr<MemoryBuffer>> CodeOrErr =
+          MemoryBuffer::getMemBuffer((*Buffer).getBuffer().take_front(
+                                         NextbundleStart /*- Magic.size()*/),
+                                     FileName, false);
+      if (std::error_code EC = CodeOrErr.getError())
+        return createFileError(FileName, EC);
+
+      Expected<std::unique_ptr<MemoryBuffer>> DecompressedBufferOrErr =
+          CompressedOffloadBundle::decompress(**CodeOrErr, false);
+      if (!DecompressedBufferOrErr)
+        return createStringError(
+            inconvertibleErrorCode(),
+            "Failed to decompress input: " +
+                llvm::toString(DecompressedBufferOrErr.takeError()));
+
+      auto FatBundleOrErr = OffloadBundleFatBin::create(
+          **DecompressedBufferOrErr, Offset, FileName, true);
+      if (!FatBundleOrErr)
+        return FatBundleOrErr.takeError();
+
+      // Add current Bundle to list.
+      Bundles.emplace_back(std::move(**FatBundleOrErr));
+
+    } else if (identify_magic((*Buffer).getBuffer()) ==
+               file_magic::offload_bundle) {
+      // Create the FatBinBindle object. This will also create the Bundle Entry
+      // list info.
+      auto FatBundleOrErr = OffloadBundleFatBin::create(
+          *Buffer, SectionOffset + Offset, FileName);
+      if (!FatBundleOrErr)
+        return FatBundleOrErr.takeError();
+
+      // Add current Bundle to list.
+      Bundles.emplace_back(std::move(**FatBundleOrErr));
+
+      Magic = StringRef("__CLANG_OFFLOAD_BUNDLE__");
+      NextbundleStart = (*Buffer).getBuffer().find(Magic, Magic.size());
+    }
 
     if (NextbundleStart != StringRef::npos)
       Offset += NextbundleStart;
@@ -102,7 +139,8 @@ Error OffloadBundleFatBin::readEntries(StringRef Buffer,
       return errorCodeToError(object_error::parse_failed);
 
     auto Entry = std::make_unique<OffloadBundleEntry>(
-        EntryOffset + SectionOffset, EntrySize, EntryIDSize, EntryID);
+        EntryOffset + SectionOffset, EntrySize, EntryIDSize,
+        std::move(EntryID.str()));
 
     Entries.push_back(*Entry);
   }
@@ -112,18 +150,22 @@ Error OffloadBundleFatBin::readEntries(StringRef Buffer,
 
 Expected<std::unique_ptr<OffloadBundleFatBin>>
 OffloadBundleFatBin::create(MemoryBufferRef Buf, uint64_t SectionOffset,
-                            StringRef FileName) {
+                            StringRef FileName, bool Decompress) {
   if (Buf.getBufferSize() < 24)
     return errorCodeToError(object_error::parse_failed);
 
   // Check for magic bytes.
-  if (identify_magic(Buf.getBuffer()) != file_magic::offload_bundle)
+  if ((identify_magic(Buf.getBuffer()) != file_magic::offload_bundle) &&
+      (identify_magic(Buf.getBuffer()) !=
+       file_magic::offload_bundle_compressed))
     return errorCodeToError(object_error::parse_failed);
 
-  OffloadBundleFatBin *TheBundle = new OffloadBundleFatBin(Buf, FileName);
+  OffloadBundleFatBin *TheBundle =
+      new OffloadBundleFatBin(Buf, FileName, Decompress);
 
   // Read the Bundle Entries
-  Error Err = TheBundle->readEntries(Buf.getBuffer(), SectionOffset);
+  Error Err =
+      TheBundle->readEntries(Buf.getBuffer(), Decompress ? 0 : SectionOffset);
   if (Err)
     return errorCodeToError(object_error::parse_failed);
 
@@ -172,28 +214,9 @@ Error object::extractOffloadBundleFatBinary(
                                  "COFF object files not supported.\n");
 
       MemoryBufferRef Contents(*Buffer, Obj.getFileName());
-
-      if (llvm::identify_magic(*Buffer) ==
-          llvm::file_magic::offload_bundle_compressed) {
-        // Decompress the input if necessary.
-        Expected<std::unique_ptr<MemoryBuffer>> DecompressedBufferOrErr =
-            CompressedOffloadBundle::decompress(Contents, false);
-
-        if (!DecompressedBufferOrErr)
-          return createStringError(
-              inconvertibleErrorCode(),
-              "Failed to decompress input: " +
-                  llvm::toString(DecompressedBufferOrErr.takeError()));
-
-        MemoryBuffer &DecompressedInput = **DecompressedBufferOrErr;
-        if (Error Err = extractOffloadBundle(DecompressedInput, SectionOffset,
-                                             Obj.getFileName(), Bundles))
-          return Err;
-      } else {
-        if (Error Err = extractOffloadBundle(Contents, SectionOffset,
-                                             Obj.getFileName(), Bundles))
-          return Err;
-      }
+      if (Error Err = extractOffloadBundle(Contents, SectionOffset,
+                                           Obj.getFileName(), Bundles))
+        return Err;
     }
   }
   return Error::success();
@@ -221,6 +244,22 @@ Error object::extractCodeObject(const ObjectFile &Source, int64_t Offset,
   return Error::success();
 }
 
+Error object::extractCodeObject(const MemoryBufferRef Buffer, int64_t Offset,
+                                int64_t Size, StringRef OutputFileName) {
+  Expected<std::unique_ptr<FileOutputBuffer>> BufferOrErr =
+      FileOutputBuffer::create(OutputFileName, Size);
+  if (!BufferOrErr)
+    return BufferOrErr.takeError();
+
+  std::unique_ptr<FileOutputBuffer> Buf = std::move(*BufferOrErr);
+  std::copy(Buffer.getBufferStart() + Offset,
+            Buffer.getBufferStart() + Offset + Size, Buf->getBufferStart());
+  if (Error E = Buf->commit())
+    return E;
+
+  return Error::success();
+}
+
 // given a file name, offset, and size, extract data into a code object file,
 // into file <SourceFile>-offset<Offset>-size<Size>.co
 Error object::extractOffloadBundleByURI(StringRef URIstr) {
@@ -260,11 +299,233 @@ static std::string formatWithCommas(unsigned long long Value) {
 }
 
 llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
-CompressedOffloadBundle::decompress(llvm::MemoryBufferRef &Input,
+CompressedOffloadBundle::compress(llvm::compression::Params P,
+                                  const llvm::MemoryBuffer &Input,
+                                  uint16_t Version, bool Verbose) {
+  if (!llvm::compression::zstd::isAvailable() &&
+      !llvm::compression::zlib::isAvailable())
+    return createStringError(llvm::inconvertibleErrorCode(),
+                             "Compression not supported");
+  llvm::Timer HashTimer("Hash Calculation Timer", "Hash calculation time",
+                        OffloadBundlerTimerGroup);
+  if (Verbose)
+    HashTimer.startTimer();
+  llvm::MD5 Hash;
+  llvm::MD5::MD5Result Result;
+  Hash.update(Input.getBuffer());
+  Hash.final(Result);
+  uint64_t TruncatedHash = Result.low();
+  if (Verbose)
+    HashTimer.stopTimer();
+
+  SmallVector<uint8_t, 0> CompressedBuffer;
+  auto BufferUint8 = llvm::ArrayRef<uint8_t>(
+      reinterpret_cast<const uint8_t *>(Input.getBuffer().data()),
+      Input.getBuffer().size());
+  llvm::Timer CompressTimer("Compression Timer", "Compression time",
+                            OffloadBundlerTimerGroup);
+  if (Verbose)
+    CompressTimer.startTimer();
+  llvm::compression::compress(P, BufferUint8, CompressedBuffer);
+  if (Verbose)
+    CompressTimer.stopTimer();
+
+  uint16_t CompressionMethod = static_cast<uint16_t>(P.format);
+
+  // Store sizes in 64-bit variables first
+  uint64_t UncompressedSize64 = Input.getBuffer().size();
+  uint64_t TotalFileSize64;
+
+  // Calculate total file size based on version
+  if (Version == 2) {
+    // For V2, ensure the sizes don't exceed 32-bit limit
+    if (UncompressedSize64 > std::numeric_limits<uint32_t>::max())
+      return createStringError(llvm::inconvertibleErrorCode(),
+                               "Uncompressed size exceeds version 2 limit");
+    if ((MagicNumber.size() + sizeof(uint32_t) + sizeof(Version) +
+         sizeof(CompressionMethod) + sizeof(uint32_t) + sizeof(TruncatedHash) +
+         CompressedBuffer.size()) > std::numeric_limits<uint32_t>::max())
+      return createStringError(llvm::inconvertibleErrorCode(),
+                               "Total file size exceeds version 2 limit");
+
+    TotalFileSize64 = MagicNumber.size() + sizeof(uint32_t) + sizeof(Version) +
+                      sizeof(CompressionMethod) + sizeof(uint32_t) +
+                      sizeof(TruncatedHash) + CompressedBuffer.size();
+  } else { // Version 3
+    TotalFileSize64 = MagicNumber.size() + sizeof(uint64_t) + sizeof(Version) +
+                      sizeof(CompressionMethod) + sizeof(uint64_t) +
+                      sizeof(TruncatedHash) + CompressedBuffer.size();
+  }
+
+  SmallVector<char, 0> FinalBuffer;
+  llvm::raw_svector_ostream OS(FinalBuffer);
+  OS << MagicNumber;
+  OS.write(reinterpret_cast<const char *>(&Version), sizeof(Version));
+  OS.write(reinterpret_cast<const char *>(&CompressionMethod),
+           sizeof(CompressionMethod));
+
+  // Write size fields according to version
+  if (Version == 2) {
+    uint32_t TotalFileSize32 = static_cast<uint32_t>(TotalFileSize64);
+    uint32_t UncompressedSize32 = static_cast<uint32_t>(UncompressedSize64);
+    OS.write(reinterpret_cast<const char *>(&TotalFileSize32),
+             sizeof(TotalFileSize32));
+    OS.write(reinterpret_cast<const char *>(&UncompressedSize32),
+             sizeof(UncompressedSize32));
+  } else { // Version 3
+    OS.write(reinterpret_cast<const char *>(&TotalFileSize64),
+             sizeof(TotalFileSize64));
+    OS.write(reinterpret_cast<const char *>(&UncompressedSize64),
+             sizeof(UncompressedSize64));
+  }
+
+  OS.write(reinterpret_cast<const char *>(&TruncatedHash),
+           sizeof(TruncatedHash));
+  OS.write(reinterpret_cast<const char *>(CompressedBuffer.data()),
+           CompressedBuffer.size());
+
+  if (Verbose) {
+    auto MethodUsed =
+        P.format == llvm::compression::Format::Zstd ? "zstd" : "zlib";
+    double CompressionRate =
+        static_cast<double>(UncompressedSize64) / CompressedBuffer.size();
+    double CompressionTimeSeconds = CompressTimer.getTotalTime().getWallTime();
+    double CompressionSpeedMBs =
+        (UncompressedSize64 / (1024.0 * 1024.0)) / CompressionTimeSeconds;
+    llvm::errs() << "Compressed bundle format version: " << Version << "\n"
+                 << "Total file size (including headers): "
+                 << formatWithCommas(TotalFileSize64) << " bytes\n"
+                 << "Compression method used: " << MethodUsed << "\n"
+                 << "Compression level: " << P.level << "\n"
+                 << "Binary size before compression: "
+                 << formatWithCommas(UncompressedSize64) << " bytes\n"
+                 << "Binary size after compression: "
+                 << formatWithCommas(CompressedBuffer.size()) << " bytes\n"
+                 << "Compression rate: "
+                 << llvm::format("%.2lf", CompressionRate) << "\n"
+                 << "Compression ratio: "
+                 << llvm::format("%.2lf%%", 100.0 / CompressionRate) << "\n"
+                 << "Compression speed: "
+                 << llvm::format("%.2lf MB/s", CompressionSpeedMBs) << "\n"
+                 << "Truncated MD5 hash: "
+                 << llvm::format_hex(TruncatedHash, 16) << "\n";
+  }
+
+  return llvm::MemoryBuffer::getMemBufferCopy(
+      llvm::StringRef(FinalBuffer.data(), FinalBuffer.size()));
+}
+
+// Use packed structs to avoid padding, such that the structs map the serialized
+// format.
+LLVM_PACKED_START
+union RawCompressedBundleHeader {
+  struct CommonFields {
+    uint32_t Magic;
+    uint16_t Version;
+    uint16_t Method;
+  };
+
+  struct V1Header {
+    CommonFields Common;
+    uint32_t UncompressedFileSize;
+    uint64_t Hash;
+  };
+
+  struct V2Header {
+    CommonFields Common;
+    uint32_t FileSize;
+    uint32_t UncompressedFileSize;
+    uint64_t Hash;
+  };
+
+  struct V3Header {
+    CommonFields Common;
+    uint64_t FileSize;
+    uint64_t UncompressedFileSize;
+    uint64_t Hash;
+  };
+
+  CommonFields Common;
+  V1Header V1;
+  V2Header V2;
+  V3Header V3;
+};
+LLVM_PACKED_END
+
+// Helper method to get header size based on version
+static size_t getHeaderSize(uint16_t Version) {
+  switch (Version) {
+  case 1:
+    return sizeof(RawCompressedBundleHeader::V1Header);
+  case 2:
+    return sizeof(RawCompressedBundleHeader::V2Header);
+  case 3:
+    return sizeof(RawCompressedBundleHeader::V3Header);
+  default:
+    llvm_unreachable("Unsupported version");
+  }
+}
+
+Expected<CompressedOffloadBundle::CompressedBundleHeader>
+CompressedOffloadBundle::CompressedBundleHeader::tryParse(StringRef Blob) {
+  assert(Blob.size() >= sizeof(RawCompressedBundleHeader::CommonFields));
+  assert(llvm::identify_magic(Blob) ==
+         llvm::file_magic::offload_bundle_compressed);
+
+  RawCompressedBundleHeader Header;
+  memcpy(&Header, Blob.data(), std::min(Blob.size(), sizeof(Header)));
+
+  CompressedBundleHeader Normalized;
+  Normalized.Version = Header.Common.Version;
+
+  size_t RequiredSize = getHeaderSize(Normalized.Version);
+
+  if (Blob.size() < RequiredSize)
+    return createStringError(inconvertibleErrorCode(),
+                             "Compressed bundle header size too small");
+
+  switch (Normalized....
[truncated]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants