[IR2Vec] Add embeddings mode to llvm-ir2vec tool #147844

Open
wants to merge 1 commit into
base: users/svkeerthy/07-09-ir2vec_tool
73 changes: 73 additions & 0 deletions llvm/test/tools/llvm-ir2vec/embeddings.ll
@@ -0,0 +1,73 @@
; RUN: llvm-ir2vec --mode=embeddings --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-DEFAULT
Contributor

Given you access %S/../../Analysis/IR2Vec/Inputs so often, I'm wondering if it makes sense to add a lit substitution. Maybe something like %ir2vec_test_vocab.

; RUN: llvm-ir2vec --mode=embeddings --level=func --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-FUNC-LEVEL
; RUN: llvm-ir2vec --mode=embeddings --level=func --function=abc --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-FUNC-LEVEL-ABC
; RUN: not llvm-ir2vec --mode=embeddings --level=func --function=def --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s 2>&1 | FileCheck %s -check-prefix=CHECK-FUNC-DEF
; RUN: llvm-ir2vec --mode=embeddings --level=bb --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-BB-LEVEL
; RUN: llvm-ir2vec --mode=embeddings --level=bb --function=abc_repeat --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-BB-LEVEL-ABC-REPEAT
; RUN: llvm-ir2vec --mode=embeddings --level=inst --function=abc_repeat --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-INST-LEVEL-ABC-REPEAT

define dso_local noundef float @abc(i32 noundef %a, float noundef %b) #0 {
entry:
  %a.addr = alloca i32, align 4
  %b.addr = alloca float, align 4
  store i32 %a, ptr %a.addr, align 4
  store float %b, ptr %b.addr, align 4
  %0 = load i32, ptr %a.addr, align 4
  %1 = load i32, ptr %a.addr, align 4
  %mul = mul nsw i32 %0, %1
  %conv = sitofp i32 %mul to float
  %2 = load float, ptr %b.addr, align 4
  %add = fadd float %conv, %2
  ret float %add
}

define dso_local noundef float @abc_repeat(i32 noundef %a, float noundef %b) #0 {
entry:
  %a.addr = alloca i32, align 4
  %b.addr = alloca float, align 4
  store i32 %a, ptr %a.addr, align 4
  store float %b, ptr %b.addr, align 4
  %0 = load i32, ptr %a.addr, align 4
  %1 = load i32, ptr %a.addr, align 4
  %mul = mul nsw i32 %0, %1
  %conv = sitofp i32 %mul to float
  %2 = load float, ptr %b.addr, align 4
  %add = fadd float %conv, %2
  ret float %add
}

; CHECK-DEFAULT: Function: abc
; CHECK-DEFAULT-NEXT: [ 878.00 889.00 900.00 ]
; CHECK-DEFAULT-NEXT: Function: abc_repeat
; CHECK-DEFAULT-NEXT: [ 878.00 889.00 900.00 ]

; CHECK-FUNC-LEVEL: Function: abc
; CHECK-FUNC-LEVEL-NEXT: [ 878.00 889.00 900.00 ]
; CHECK-FUNC-LEVEL-NEXT: Function: abc_repeat
; CHECK-FUNC-LEVEL-NEXT: [ 878.00 889.00 900.00 ]

; CHECK-FUNC-LEVEL-ABC: Function: abc
; CHECK-FUNC-LEVEL-ABC-NEXT: [ 878.00 889.00 900.00 ]

; CHECK-FUNC-DEF: Error: Function 'def' not found

; CHECK-BB-LEVEL: Function: abc
; CHECK-BB-LEVEL-NEXT: entry: [ 878.00 889.00 900.00 ]
; CHECK-BB-LEVEL-NEXT: Function: abc_repeat
; CHECK-BB-LEVEL-NEXT: entry: [ 878.00 889.00 900.00 ]

; CHECK-BB-LEVEL-ABC-REPEAT: Function: abc_repeat
; CHECK-BB-LEVEL-ABC-REPEAT-NEXT: entry: [ 878.00 889.00 900.00 ]

; CHECK-INST-LEVEL-ABC-REPEAT: Function: abc_repeat
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %a.addr = alloca i32, align 4 [ 91.00 92.00 93.00 ]
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %b.addr = alloca float, align 4 [ 91.00 92.00 93.00 ]
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: store i32 %a, ptr %a.addr, align 4 [ 97.00 98.00 99.00 ]
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: store float %b, ptr %b.addr, align 4 [ 97.00 98.00 99.00 ]
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %0 = load i32, ptr %a.addr, align 4 [ 94.00 95.00 96.00 ]
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %1 = load i32, ptr %a.addr, align 4 [ 94.00 95.00 96.00 ]
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %mul = mul nsw i32 %0, %1 [ 49.00 50.00 51.00 ]
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %conv = sitofp i32 %mul to float [ 130.00 131.00 132.00 ]
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %2 = load float, ptr %b.addr, align 4 [ 94.00 95.00 96.00 ]
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %add = fadd float %conv, %2 [ 40.00 41.00 42.00 ]
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: ret float %add [ 1.00 2.00 3.00 ]
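
The expected values above can be cross-checked by hand: adding up the eleven instruction-level vectors for `abc_repeat` component by component gives exactly the function-level and basic-block-level vector `[ 878.00 889.00 900.00 ]`. The sketch below only replays that arithmetic on the numbers from the CHECK lines. It is not the embedder's implementation, and whether the embedder computes the function vector as a plain sum is an assumption here; the names `Vec3`, `instructionVectors`, and `sumVectors` are hypothetical.

```cpp
#include <array>
#include <cstddef>
#include <vector>

using Vec3 = std::array<double, 3>;

// Instruction-level vectors copied from the CHECK-INST-LEVEL-ABC-REPEAT
// lines, in program order.
inline std::vector<Vec3> instructionVectors() {
  return {
      Vec3{91, 92, 93},    // %a.addr = alloca i32
      Vec3{91, 92, 93},    // %b.addr = alloca float
      Vec3{97, 98, 99},    // store i32 %a
      Vec3{97, 98, 99},    // store float %b
      Vec3{94, 95, 96},    // %0 = load i32
      Vec3{94, 95, 96},    // %1 = load i32
      Vec3{49, 50, 51},    // %mul = mul nsw i32
      Vec3{130, 131, 132}, // %conv = sitofp
      Vec3{94, 95, 96},    // %2 = load float
      Vec3{40, 41, 42},    // %add = fadd
      Vec3{1, 2, 3},       // ret float
  };
}

// Element-wise sum of a list of 3-D vectors.
inline Vec3 sumVectors(const std::vector<Vec3> &Vs) {
  Vec3 Sum = {0.0, 0.0, 0.0};
  for (const Vec3 &V : Vs)
    for (std::size_t I = 0; I < Sum.size(); ++I)
      Sum[I] += V[I];
  return Sum;
}
```

Calling `sumVectors(instructionVectors())` yields the three components 878, 889, and 900, which matches the CHECK-FUNC-LEVEL and CHECK-BB-LEVEL expectations for a function with a single `entry` block.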
2 changes: 1 addition & 1 deletion llvm/test/tools/llvm-ir2vec/triplets.ll
@@ -1,4 +1,4 @@
-; RUN: llvm-ir2vec %s | FileCheck %s -check-prefix=TRIPLETS
+; RUN: llvm-ir2vec --mode=triplets %s | FileCheck %s -check-prefix=TRIPLETS

define i32 @simple_add(i32 %a, i32 %b) {
entry:
185 changes: 174 additions & 11 deletions llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp
@@ -9,12 +9,18 @@
/// \file
/// This file implements the IR2Vec embedding generation tool.
///
-/// Currently supports triplet generation for vocabulary training.
-/// Future updates will support embedding generation using trained vocabulary.
+/// This tool provides two main functionalities:
///
-/// Usage: llvm-ir2vec input.bc -o triplets.txt
+/// 1. Triplet Generation Mode (--mode=triplets):
+///    Generates triplets (opcode, type, operands) for vocabulary training.
+///    Usage: llvm-ir2vec --mode=triplets input.bc -o triplets.txt
///
-/// TODO: Add embedding generation mode with vocabulary support
+/// 2. Embedding Generation Mode (--mode=embeddings):
+///    Generates IR2Vec embeddings using a trained vocabulary.
+///    Usage: llvm-ir2vec --mode=embeddings --ir2vec-vocab-path=vocab.json
+///           --level=func input.bc -o embeddings.txt
+///    Levels: --level=inst (instructions), --level=bb (basic blocks),
+///            --level=func (functions)
+///    (See IR2Vec.cpp for more embedding generation options)
///
//===----------------------------------------------------------------------===//

@@ -24,6 +30,8 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassInstrumentation.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Type.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/CommandLine.h"
@@ -34,7 +42,7 @@
#include "llvm/Support/raw_ostream.h"

using namespace llvm;
-using namespace ir2vec;
+using namespace llvm::ir2vec;

#define DEBUG_TYPE "ir2vec"

@@ -50,16 +58,63 @@ static cl::opt<std::string> OutputFilename("o", cl::desc("Output filename"),
                                           cl::init("-"),
                                           cl::cat(IR2VecToolCategory));

enum ToolMode {
  TripletMode,  // Generate triplets for vocabulary training
  EmbeddingMode // Generate embeddings using trained vocabulary
};

static cl::opt<ToolMode>
    Mode("mode", cl::desc("Tool operation mode:"),
         cl::values(clEnumValN(TripletMode, "triplets",
                               "Generate triplets for vocabulary training"),
                    clEnumValN(EmbeddingMode, "embeddings",
                               "Generate embeddings using trained vocabulary")),
         cl::init(EmbeddingMode), cl::cat(IR2VecToolCategory));

static cl::opt<std::string>
    FunctionName("function", cl::desc("Process specific function only"),
                 cl::value_desc("name"), cl::Optional, cl::init(""),
                 cl::cat(IR2VecToolCategory));

enum EmbeddingLevel {
  InstructionLevel, // Generate instruction-level embeddings
  BasicBlockLevel,  // Generate basic block-level embeddings
  FunctionLevel     // Generate function-level embeddings
};

static cl::opt<EmbeddingLevel>
    Level("level", cl::desc("Embedding generation level (for embedding mode):"),
          cl::values(clEnumValN(InstructionLevel, "inst",
                                "Generate instruction-level embeddings"),
                     clEnumValN(BasicBlockLevel, "bb",
                                "Generate basic block-level embeddings"),
                     clEnumValN(FunctionLevel, "func",
                                "Generate function-level embeddings")),
          cl::init(FunctionLevel), cl::cat(IR2VecToolCategory));

namespace {

-/// Helper class for collecting IR information and generating triplets
+/// Helper class for collecting IR information and generating embeddings
Contributor

This comment seems incorrect given this can do embeddings and triplets?

Contributor Author

Thanks for the catch. Will fix it.

class IR2VecTool {
private:
  Module &M;
  ModuleAnalysisManager MAM;
  const Vocabulary *Vocab = nullptr;

public:
  explicit IR2VecTool(Module &M) : M(M) {}

  /// Initialize the IR2Vec vocabulary analysis
  bool initializeVocabulary() {
    // Register and run the IR2Vec vocabulary analysis
    // The vocabulary file path is specified via --ir2vec-vocab-path global
    // option
    MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
    MAM.registerPass([&] { return IR2VecVocabAnalysis(); });
    Vocab = &MAM.getResult<IR2VecVocabAnalysis>(M);
    return Vocab->isValid();
  }

  /// Generate triplets for the entire module
  void generateTriplets(raw_ostream &OS) const {
    for (const Function &F : M)
@@ -81,6 +136,70 @@ class IR2VecTool {
    OS << LocalOutput;
  }

  /// Generate embeddings for the entire module
  void generateEmbeddings(raw_ostream &OS) const {
    if (!Vocab->isValid()) {
      OS << "Error: Vocabulary is not valid. IR2VecTool not initialized.\n";
      return;
    }

    for (const Function &F : M)
      generateEmbeddings(F, OS);
  }

  /// Generate embeddings for a single function
  void generateEmbeddings(const Function &F, raw_ostream &OS) const {
    if (F.isDeclaration()) {
      OS << "Function " << F.getName() << " is a declaration, skipping.\n";
      return;
    }

    // Create embedder for this function
    assert(Vocab->isValid() && "Vocabulary is not valid");
    auto Emb = Embedder::create(IR2VecKind::Symbolic, F, *Vocab);
    if (!Emb) {
      OS << "Error: Failed to create embedder for function " << F.getName()
         << "\n";
      return;
    }

    OS << "Function: " << F.getName() << "\n";

    // Generate embeddings based on the specified level
    switch (Level) {
    case FunctionLevel: {
      Emb->getFunctionVector().print(OS);
      break;
    }
    case BasicBlockLevel: {
      const auto &BBVecMap = Emb->getBBVecMap();
      for (const BasicBlock &BB : F) {
        auto It = BBVecMap.find(&BB);
        if (It != BBVecMap.end()) {
          OS << BB.getName() << ":";
          It->second.print(OS);
        }
      }
      break;
    }
    case InstructionLevel: {
      const auto &InstMap = Emb->getInstVecMap();
      for (const BasicBlock &BB : F) {
        for (const Instruction &I : BB) {
          auto It = InstMap.find(&I);
          if (It != InstMap.end()) {
            I.print(OS);
            It->second.print(OS);
          }
        }
      }
      break;
    }
    }

    // OS << "\n";
Contributor

Remove?

  }

private:
  /// Process a single basic block for triplet generation
  void traverseBasicBlock(const BasicBlock &BB, raw_string_ostream &OS) const {
@@ -105,8 +224,42 @@ class IR2VecTool {

Error processModule(Module &M, raw_ostream &OS) {
  IR2VecTool Tool(M);
- Tool.generateTriplets(OS);

  if (Mode == EmbeddingMode) {
    // Initialize vocabulary for embedding generation
    // Note: Requires --ir2vec-vocab-path option to be set
    if (!Tool.initializeVocabulary())
      return createStringError(
          errc::invalid_argument,
          "Failed to initialize IR2Vec vocabulary. "
          "Make sure to specify --ir2vec-vocab-path for embedding mode.");

    if (!FunctionName.empty()) {
      // Process single function
      if (const Function *F = M.getFunction(FunctionName))
        Tool.generateEmbeddings(*F, OS);
      else
        return createStringError(errc::invalid_argument,
                                 "Function '%s' not found",
                                 FunctionName.c_str());
    } else {
      // Process all functions
      Tool.generateEmbeddings(OS);
    }
  } else {
    // Triplet generation mode - no vocabulary needed
    if (!FunctionName.empty())
      // Process single function
      if (const Function *F = M.getFunction(FunctionName))
        Tool.generateTriplets(*F, OS);
      else
        return createStringError(errc::invalid_argument,
                                 "Function '%s' not found",
                                 FunctionName.c_str());
    else
      // Process all functions
      Tool.generateTriplets(OS);
  }
  return Error::success();
}
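
The function-selection logic in `processModule` is the same in both modes: an empty `--function` value means every function is processed, and a non-empty one selects exactly one function or fails with `Function '<name>' not found`. A minimal standalone sketch of that control flow follows. It does not use LLVM; `SelectionResult` and `selectFunctions` are hypothetical names, and the module is modelled as a plain list of function names. The `Error: ` prefix seen in the CHECK-FUNC-DEF line is assumed to be added by the caller and is not modelled here.

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Hypothetical result type: which functions would be processed, or an error.
struct SelectionResult {
  std::vector<std::string> Processed; // functions handed to the generator
  std::string Error;                  // empty on success
};

// An empty FunctionName selects every function in the module; otherwise
// exactly one function is looked up and a missing name is reported.
inline SelectionResult
selectFunctions(const std::vector<std::string> &ModuleFunctions,
                const std::string &FunctionName) {
  SelectionResult R;
  if (FunctionName.empty()) {
    R.Processed = ModuleFunctions;
    return R;
  }
  if (std::find(ModuleFunctions.begin(), ModuleFunctions.end(),
                FunctionName) == ModuleFunctions.end()) {
    R.Error = "Function '" + FunctionName + "' not found";
    return R;
  }
  R.Processed.push_back(FunctionName);
  return R;
}
```

With the two functions from embeddings.ll, an empty name selects both `abc` and `abc_repeat`, `abc` selects one, and `def` produces the error string that the `not llvm-ir2vec ... --function=def` RUN line checks for.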

@@ -117,11 +270,21 @@ int main(int argc, char **argv) {
  cl::HideUnrelatedOptions(IR2VecToolCategory);
  cl::ParseCommandLineOptions(
      argc, argv,
-     "IR2Vec - Triplet Generation Tool\n"
-     "Generates triplets for vocabulary training from LLVM IR.\n"
-     "Future updates will support embedding generation.\n\n"
+     "IR2Vec - Embedding Generation Tool\n"
+     "Generates embeddings for a given LLVM IR and "
+     "supports triplet generation for vocabulary "
+     "training and embedding generation.\n\n"
      "Usage:\n"
-     " llvm-ir2vec input.bc -o triplets.txt\n");
+     " Triplet mode: llvm-ir2vec --mode=triplets input.bc\n"
+     " Embedding mode: llvm-ir2vec --mode=embeddings "
+     "--ir2vec-vocab-path=vocab.json --level=func input.bc\n"
+     " Levels: --level=inst (instructions), --level=bb (basic blocks), "
+     "--level=func (functions)\n");

  // Validate command line options
  if (Mode == TripletMode && Level != FunctionLevel) {
Contributor

This might be a bit confusing if the user explicitly requests function level triplet mode. I think there's a way to see if the user has actually passed a flag with cl::opt, but I'm not sure.

    errs() << "Warning: --level option is ignored in triplet mode\n";
  }

  // Parse the input LLVM IR file
  SMDiagnostic Err;
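
On the review question about the `--level` validation: comparing `Level` against its default cannot tell an explicit `--level=func` apart from no flag at all, so the warning is skipped in exactly that case. LLVM's `cl::Option` exposes `getNumOccurrences()`, which may be the mechanism the reviewer has in mind. The sketch below does not use LLVM. It models the idea with a hypothetical `LevelOption` type that counts how often the flag appears; the parsing is deliberately simplified and is only an assumption about how such a check could look.

```cpp
#include <string>
#include <vector>

enum class EmbeddingLevel { Inst, BB, Func };

// Hypothetical, simplified option type (not LLVM's cl::opt): it stores a
// value plus the number of times the flag appeared on the command line.
struct LevelOption {
  EmbeddingLevel Value = EmbeddingLevel::Func; // default value
  unsigned NumOccurrences = 0;

  // Parse "--level=<inst|bb|func>" arguments; other arguments are ignored.
  void parse(const std::vector<std::string> &Args) {
    const std::string Prefix = "--level=";
    for (const std::string &Arg : Args) {
      if (Arg.compare(0, Prefix.size(), Prefix) != 0)
        continue;
      const std::string Name = Arg.substr(Prefix.size());
      if (Name == "inst")
        Value = EmbeddingLevel::Inst;
      else if (Name == "bb")
        Value = EmbeddingLevel::BB;
      else if (Name == "func")
        Value = EmbeddingLevel::Func;
      else
        continue; // unknown value: ignored in this sketch
      ++NumOccurrences;
    }
  }
};

// Warn whenever --level was passed explicitly in triplet mode, including an
// explicit --level=func, which a comparison against the default misses.
inline bool shouldWarnLevelIgnored(bool IsTripletMode,
                                   const LevelOption &Level) {
  return IsTripletMode && Level.NumOccurrences > 0;
}
```

Under this scheme, `--mode=triplets --level=func` triggers the warning while `--mode=triplets` alone does not. In the tool itself the analogous condition would presumably be written against `Level.getNumOccurrences()`.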