-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[IR2Vec] Add embeddings mode to llvm-ir2vec tool #147844
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: users/svkeerthy/07-09-ir2vec_tool
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
; RUN: llvm-ir2vec --mode=embeddings --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-DEFAULT | ||
; RUN: llvm-ir2vec --mode=embeddings --level=func --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-FUNC-LEVEL | ||
; RUN: llvm-ir2vec --mode=embeddings --level=func --function=abc --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-FUNC-LEVEL-ABC | ||
; RUN: not llvm-ir2vec --mode=embeddings --level=func --function=def --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s 2>&1 | FileCheck %s -check-prefix=CHECK-FUNC-DEF | ||
; RUN: llvm-ir2vec --mode=embeddings --level=bb --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-BB-LEVEL | ||
; RUN: llvm-ir2vec --mode=embeddings --level=bb --function=abc_repeat --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-BB-LEVEL-ABC-REPEAT | ||
; RUN: llvm-ir2vec --mode=embeddings --level=inst --function=abc_repeat --ir2vec-vocab-path=%S/../../Analysis/IR2Vec/Inputs/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-INST-LEVEL-ABC-REPEAT | ||
|
||
; Test input: spills %a and %b to stack slots, then returns float(%a * %a) + %b. | ||
define dso_local noundef float @abc(i32 noundef %a, float noundef %b) #0 { | ||
entry: | ||
%a.addr = alloca i32, align 4 | ||
%b.addr = alloca float, align 4 | ||
store i32 %a, ptr %a.addr, align 4 | ||
store float %b, ptr %b.addr, align 4 | ||
%0 = load i32, ptr %a.addr, align 4 | ||
%1 = load i32, ptr %a.addr, align 4 | ||
%mul = mul nsw i32 %0, %1 | ||
%conv = sitofp i32 %mul to float | ||
%2 = load float, ptr %b.addr, align 4 | ||
%add = fadd float %conv, %2 | ||
ret float %add | ||
} | ||
|
||
; Test input: same body as @abc; the checks below expect identical embeddings for both. | ||
define dso_local noundef float @abc_repeat(i32 noundef %a, float noundef %b) #0 { | ||
entry: | ||
%a.addr = alloca i32, align 4 | ||
%b.addr = alloca float, align 4 | ||
store i32 %a, ptr %a.addr, align 4 | ||
store float %b, ptr %b.addr, align 4 | ||
%0 = load i32, ptr %a.addr, align 4 | ||
%1 = load i32, ptr %a.addr, align 4 | ||
%mul = mul nsw i32 %0, %1 | ||
%conv = sitofp i32 %mul to float | ||
%2 = load float, ptr %b.addr, align 4 | ||
%add = fadd float %conv, %2 | ||
ret float %add | ||
} | ||
|
||
; CHECK-DEFAULT: Function: abc | ||
; CHECK-DEFAULT-NEXT: [ 878.00 889.00 900.00 ] | ||
; CHECK-DEFAULT-NEXT: Function: abc_repeat | ||
; CHECK-DEFAULT-NEXT: [ 878.00 889.00 900.00 ] | ||
|
||
; CHECK-FUNC-LEVEL: Function: abc | ||
; CHECK-FUNC-LEVEL-NEXT: [ 878.00 889.00 900.00 ] | ||
; CHECK-FUNC-LEVEL-NEXT: Function: abc_repeat | ||
; CHECK-FUNC-LEVEL-NEXT: [ 878.00 889.00 900.00 ] | ||
|
||
; CHECK-FUNC-LEVEL-ABC: Function: abc | ||
; CHECK-FUNC-LEVEL-ABC-NEXT: [ 878.00 889.00 900.00 ] | ||
|
||
; CHECK-FUNC-DEF: Error: Function 'def' not found | ||
|
||
; CHECK-BB-LEVEL: Function: abc | ||
; CHECK-BB-LEVEL-NEXT: entry: [ 878.00 889.00 900.00 ] | ||
; CHECK-BB-LEVEL-NEXT: Function: abc_repeat | ||
; CHECK-BB-LEVEL-NEXT: entry: [ 878.00 889.00 900.00 ] | ||
|
||
; CHECK-BB-LEVEL-ABC-REPEAT: Function: abc_repeat | ||
; CHECK-BB-LEVEL-ABC-REPEAT-NEXT: entry: [ 878.00 889.00 900.00 ] | ||
|
||
; CHECK-INST-LEVEL-ABC-REPEAT: Function: abc_repeat | ||
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %a.addr = alloca i32, align 4 [ 91.00 92.00 93.00 ] | ||
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %b.addr = alloca float, align 4 [ 91.00 92.00 93.00 ] | ||
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: store i32 %a, ptr %a.addr, align 4 [ 97.00 98.00 99.00 ] | ||
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: store float %b, ptr %b.addr, align 4 [ 97.00 98.00 99.00 ] | ||
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %0 = load i32, ptr %a.addr, align 4 [ 94.00 95.00 96.00 ] | ||
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %1 = load i32, ptr %a.addr, align 4 [ 94.00 95.00 96.00 ] | ||
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %mul = mul nsw i32 %0, %1 [ 49.00 50.00 51.00 ] | ||
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %conv = sitofp i32 %mul to float [ 130.00 131.00 132.00 ] | ||
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %2 = load float, ptr %b.addr, align 4 [ 94.00 95.00 96.00 ] | ||
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %add = fadd float %conv, %2 [ 40.00 41.00 42.00 ] | ||
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: ret float %add [ 1.00 2.00 3.00 ] |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,12 +9,18 @@ | |
/// \file | ||
/// This file implements the IR2Vec embedding generation tool. | ||
/// | ||
/// Currently supports triplet generation for vocabulary training. | ||
/// Future updates will support embedding generation using trained vocabulary. | ||
/// This tool provides two main functionalities: | ||
/// | ||
/// Usage: llvm-ir2vec input.bc -o triplets.txt | ||
/// 1. Triplet Generation Mode (--mode=triplets): | ||
/// Generates triplets (opcode, type, operands) for vocabulary training. | ||
/// Usage: llvm-ir2vec --mode=triplets input.bc -o triplets.txt | ||
/// | ||
/// TODO: Add embedding generation mode with vocabulary support | ||
/// 2. Embedding Generation Mode (--mode=embeddings): | ||
/// Generates IR2Vec embeddings using a trained vocabulary. | ||
/// Usage: llvm-ir2vec --mode=embeddings --ir2vec-vocab-path=vocab.json | ||
///        --level=func input.bc -o embeddings.txt | ||
/// Levels: --level=inst (instructions), --level=bb (basic blocks), | ||
///         --level=func (functions) | ||
/// (See IR2Vec.cpp for more embedding generation options) | ||
/// | ||
//===----------------------------------------------------------------------===// | ||
|
||
|
@@ -24,6 +30,8 @@ | |
#include "llvm/IR/Instructions.h" | ||
#include "llvm/IR/LLVMContext.h" | ||
#include "llvm/IR/Module.h" | ||
#include "llvm/IR/PassInstrumentation.h" | ||
#include "llvm/IR/PassManager.h" | ||
#include "llvm/IR/Type.h" | ||
#include "llvm/IRReader/IRReader.h" | ||
#include "llvm/Support/CommandLine.h" | ||
|
@@ -34,7 +42,7 @@ | |
#include "llvm/Support/raw_ostream.h" | ||
|
||
using namespace llvm; | ||
using namespace ir2vec; | ||
using namespace llvm::ir2vec; | ||
|
||
#define DEBUG_TYPE "ir2vec" | ||
|
||
|
@@ -50,16 +58,63 @@ static cl::opt<std::string> OutputFilename("o", cl::desc("Output filename"), | |
cl::init("-"), | ||
cl::cat(IR2VecToolCategory)); | ||
|
||
enum ToolMode { | ||
TripletMode, // Generate triplets for vocabulary training | ||
EmbeddingMode // Generate embeddings using trained vocabulary | ||
}; | ||
|
||
static cl::opt<ToolMode> | ||
Mode("mode", cl::desc("Tool operation mode:"), | ||
cl::values(clEnumValN(TripletMode, "triplets", | ||
"Generate triplets for vocabulary training"), | ||
clEnumValN(EmbeddingMode, "embeddings", | ||
"Generate embeddings using trained vocabulary")), | ||
cl::init(EmbeddingMode), cl::cat(IR2VecToolCategory)); | ||
|
||
/// The optional --function=<name> option. It defaults to the empty string.
static cl::opt<std::string> FunctionName(
    "function", cl::Optional, cl::value_desc("name"),
    cl::desc("Process specific function only"), cl::init(""),
    cl::cat(IR2VecToolCategory));
|
||
enum EmbeddingLevel { | ||
InstructionLevel, // Generate instruction-level embeddings | ||
BasicBlockLevel, // Generate basic block-level embeddings | ||
FunctionLevel // Generate function-level embeddings | ||
}; | ||
|
||
static cl::opt<EmbeddingLevel> | ||
Level("level", cl::desc("Embedding generation level (for embedding mode):"), | ||
cl::values(clEnumValN(InstructionLevel, "inst", | ||
"Generate instruction-level embeddings"), | ||
clEnumValN(BasicBlockLevel, "bb", | ||
"Generate basic block-level embeddings"), | ||
clEnumValN(FunctionLevel, "func", | ||
"Generate function-level embeddings")), | ||
cl::init(FunctionLevel), cl::cat(IR2VecToolCategory)); | ||
|
||
namespace { | ||
|
||
/// Helper class for collecting IR information and generating triplets | ||
/// Helper class for collecting IR information and generating triplets and | ||
/// embeddings | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This comment seems incorrect given this can do embeddings and triplets? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the catch. Will fix it. |
||
class IR2VecTool { | ||
private: | ||
Module &M; | ||
ModuleAnalysisManager MAM; | ||
const Vocabulary *Vocab = nullptr; | ||
|
||
public: | ||
explicit IR2VecTool(Module &M) : M(M) {} | ||
|
||
/// Initialize the IR2Vec vocabulary analysis | ||
bool initializeVocabulary() { | ||
// Register and run the IR2Vec vocabulary analysis | ||
// The vocabulary file path is specified via --ir2vec-vocab-path global | ||
// option | ||
MAM.registerPass([&] { return PassInstrumentationAnalysis(); }); | ||
MAM.registerPass([&] { return IR2VecVocabAnalysis(); }); | ||
Vocab = &MAM.getResult<IR2VecVocabAnalysis>(M); | ||
return Vocab->isValid(); | ||
} | ||
|
||
/// Generate triplets for the entire module | ||
void generateTriplets(raw_ostream &OS) const { | ||
for (const Function &F : M) | ||
|
@@ -81,6 +136,70 @@ class IR2VecTool { | |
OS << LocalOutput; | ||
} | ||
|
||
/// Generate embeddings for every function in the module and write them to
/// \p OS.
///
/// If the vocabulary has not been initialized (Vocab is still null because
/// initializeVocabulary() was never called) or is invalid, an error message
/// is written to \p OS and nothing else is emitted.
void generateEmbeddings(raw_ostream &OS) const {
  // Vocab stays null until initializeVocabulary() has run, so test the
  // pointer before dereferencing it.
  if (!Vocab || !Vocab->isValid()) {
    OS << "Error: Vocabulary is not valid. IR2VecTool not initialized.\n";
    return;
  }

  for (const Function &F : M)
    generateEmbeddings(F, OS);
}
|
||
/// Generate embeddings for a single function | ||
void generateEmbeddings(const Function &F, raw_ostream &OS) const { | ||
if (F.isDeclaration()) { | ||
OS << "Function " << F.getName() << " is a declaration, skipping.\n"; | ||
return; | ||
} | ||
|
||
// Create embedder for this function | ||
assert(Vocab->isValid() && "Vocabulary is not valid"); | ||
auto Emb = Embedder::create(IR2VecKind::Symbolic, F, *Vocab); | ||
if (!Emb) { | ||
OS << "Error: Failed to create embedder for function " << F.getName() | ||
<< "\n"; | ||
return; | ||
} | ||
|
||
OS << "Function: " << F.getName() << "\n"; | ||
|
||
// Generate embeddings based on the specified level | ||
switch (Level) { | ||
case FunctionLevel: { | ||
Emb->getFunctionVector().print(OS); | ||
break; | ||
} | ||
case BasicBlockLevel: { | ||
const auto &BBVecMap = Emb->getBBVecMap(); | ||
for (const BasicBlock &BB : F) { | ||
auto It = BBVecMap.find(&BB); | ||
if (It != BBVecMap.end()) { | ||
OS << BB.getName() << ":"; | ||
It->second.print(OS); | ||
} | ||
} | ||
break; | ||
} | ||
case InstructionLevel: { | ||
const auto &InstMap = Emb->getInstVecMap(); | ||
for (const BasicBlock &BB : F) { | ||
for (const Instruction &I : BB) { | ||
auto It = InstMap.find(&I); | ||
if (It != InstMap.end()) { | ||
I.print(OS); | ||
It->second.print(OS); | ||
} | ||
} | ||
} | ||
break; | ||
} | ||
} | ||
|
||
// OS << "\n"; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove? |
||
} | ||
|
||
private: | ||
/// Process a single basic block for triplet generation | ||
void traverseBasicBlock(const BasicBlock &BB, raw_string_ostream &OS) const { | ||
|
@@ -105,8 +224,42 @@ class IR2VecTool { | |
|
||
Error processModule(Module &M, raw_ostream &OS) { | ||
IR2VecTool Tool(M); | ||
Tool.generateTriplets(OS); | ||
|
||
if (Mode == EmbeddingMode) { | ||
// Initialize vocabulary for embedding generation | ||
// Note: Requires --ir2vec-vocab-path option to be set | ||
if (!Tool.initializeVocabulary()) | ||
return createStringError( | ||
errc::invalid_argument, | ||
"Failed to initialize IR2Vec vocabulary. " | ||
"Make sure to specify --ir2vec-vocab-path for embedding mode."); | ||
|
||
if (!FunctionName.empty()) { | ||
// Process single function | ||
if (const Function *F = M.getFunction(FunctionName)) | ||
Tool.generateEmbeddings(*F, OS); | ||
else | ||
return createStringError(errc::invalid_argument, | ||
"Function '%s' not found", | ||
FunctionName.c_str()); | ||
} else { | ||
// Process all functions | ||
Tool.generateEmbeddings(OS); | ||
} | ||
} else { | ||
// Triplet generation mode - no vocabulary needed | ||
if (!FunctionName.empty()) | ||
// Process single function | ||
if (const Function *F = M.getFunction(FunctionName)) | ||
Tool.generateTriplets(*F, OS); | ||
else | ||
return createStringError(errc::invalid_argument, | ||
"Function '%s' not found", | ||
FunctionName.c_str()); | ||
else | ||
// Process all functions | ||
Tool.generateTriplets(OS); | ||
} | ||
return Error::success(); | ||
} | ||
|
||
|
@@ -117,11 +270,21 @@ int main(int argc, char **argv) { | |
cl::HideUnrelatedOptions(IR2VecToolCategory); | ||
cl::ParseCommandLineOptions( | ||
argc, argv, | ||
"IR2Vec - Triplet Generation Tool\n" | ||
"Generates triplets for vocabulary training from LLVM IR.\n" | ||
"Future updates will support embedding generation.\n\n" | ||
"IR2Vec - Embedding Generation Tool\n" | ||
"Generates embeddings for a given LLVM IR and " | ||
"supports triplet generation for vocabulary " | ||
"training and embedding generation.\n\n" | ||
"Usage:\n" | ||
" llvm-ir2vec input.bc -o triplets.txt\n"); | ||
" Triplet mode: llvm-ir2vec --mode=triplets input.bc\n" | ||
" Embedding mode: llvm-ir2vec --mode=embeddings " | ||
"--ir2vec-vocab-path=vocab.json --level=func input.bc\n" | ||
" Levels: --level=inst (instructions), --level=bb (basic blocks), " | ||
"--level=func (functions)\n"); | ||
|
||
// Validate command line options | ||
if (Mode == TripletMode && Level != FunctionLevel) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This might be a bit confusing if the user explicitly requests function level triplet mode. I think there's a way to see if the user has actually passed a flag with |
||
errs() << "Warning: --level option is ignored in triplet mode\n"; | ||
} | ||
|
||
// Parse the input LLVM IR file | ||
SMDiagnostic Err; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given you access
%S/../../Analysis/IR2Vec/Inputs
so often, I'm wondering if it makes sense to add a lit substitution. Maybe something like %ir2vec_test_vocab
or something.