Skip to content

Commit 684d298

Browse files
committed
IR2Vec Tool Enhancements
1 parent 5f1f3fe commit 684d298

File tree

4 files changed

+250
-12
lines changed

4 files changed

+250
-12
lines changed

llvm/test/lit.cfg.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ def get_asan_rtlib():
9393
config.substitutions.append(("%exeext", config.llvm_exe_ext))
9494
config.substitutions.append(("%llvm_src_root", config.llvm_src_root))
9595

96+
# Add IR2Vec test vocabulary path substitution
97+
config.substitutions.append(("%ir2vec_test_vocab_dir",
98+
os.path.join(config.test_source_root,
99+
"Analysis", "IR2Vec", "Inputs")))
96100

97101
lli_args = []
98102
# The target triple used by default by lli is the process target triple (some
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
; RUN: llvm-ir2vec --mode=embeddings --ir2vec-vocab-path=%ir2vec_test_vocab_dir/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-DEFAULT
2+
; RUN: llvm-ir2vec --mode=embeddings --level=func --ir2vec-vocab-path=%ir2vec_test_vocab_dir/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-FUNC-LEVEL
3+
; RUN: llvm-ir2vec --mode=embeddings --level=func --function=abc --ir2vec-vocab-path=%ir2vec_test_vocab_dir/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-FUNC-LEVEL-ABC
4+
; RUN: not llvm-ir2vec --mode=embeddings --level=func --function=def --ir2vec-vocab-path=%ir2vec_test_vocab_dir/dummy_3D_nonzero_opc_vocab.json %s 2>&1 | FileCheck %s -check-prefix=CHECK-FUNC-DEF
5+
; RUN: llvm-ir2vec --mode=embeddings --level=bb --ir2vec-vocab-path=%ir2vec_test_vocab_dir/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-BB-LEVEL
6+
; RUN: llvm-ir2vec --mode=embeddings --level=bb --function=abc_repeat --ir2vec-vocab-path=%ir2vec_test_vocab_dir/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-BB-LEVEL-ABC-REPEAT
7+
; RUN: llvm-ir2vec --mode=embeddings --level=inst --function=abc_repeat --ir2vec-vocab-path=%ir2vec_test_vocab_dir/dummy_3D_nonzero_opc_vocab.json %s | FileCheck %s -check-prefix=CHECK-INST-LEVEL-ABC-REPEAT
8+
9+
define dso_local noundef float @abc(i32 noundef %a, float noundef %b) #0 {
10+
entry:
11+
%a.addr = alloca i32, align 4
12+
%b.addr = alloca float, align 4
13+
store i32 %a, ptr %a.addr, align 4
14+
store float %b, ptr %b.addr, align 4
15+
%0 = load i32, ptr %a.addr, align 4
16+
%1 = load i32, ptr %a.addr, align 4
17+
%mul = mul nsw i32 %0, %1
18+
%conv = sitofp i32 %mul to float
19+
%2 = load float, ptr %b.addr, align 4
20+
%add = fadd float %conv, %2
21+
ret float %add
22+
}
23+
24+
define dso_local noundef float @abc_repeat(i32 noundef %a, float noundef %b) #0 {
25+
entry:
26+
%a.addr = alloca i32, align 4
27+
%b.addr = alloca float, align 4
28+
store i32 %a, ptr %a.addr, align 4
29+
store float %b, ptr %b.addr, align 4
30+
%0 = load i32, ptr %a.addr, align 4
31+
%1 = load i32, ptr %a.addr, align 4
32+
%mul = mul nsw i32 %0, %1
33+
%conv = sitofp i32 %mul to float
34+
%2 = load float, ptr %b.addr, align 4
35+
%add = fadd float %conv, %2
36+
ret float %add
37+
}
38+
39+
; CHECK-DEFAULT: Function: abc
40+
; CHECK-DEFAULT-NEXT: [ 878.00 889.00 900.00 ]
41+
; CHECK-DEFAULT-NEXT: Function: abc_repeat
42+
; CHECK-DEFAULT-NEXT: [ 878.00 889.00 900.00 ]
43+
44+
; CHECK-FUNC-LEVEL: Function: abc
45+
; CHECK-FUNC-LEVEL-NEXT: [ 878.00 889.00 900.00 ]
46+
; CHECK-FUNC-LEVEL-NEXT: Function: abc_repeat
47+
; CHECK-FUNC-LEVEL-NEXT: [ 878.00 889.00 900.00 ]
48+
49+
; CHECK-FUNC-LEVEL-ABC: Function: abc
50+
; CHECK-FUNC-LEVEL-ABC-NEXT: [ 878.00 889.00 900.00 ]
51+
52+
; CHECK-FUNC-DEF: Error: Function 'def' not found
53+
54+
; CHECK-BB-LEVEL: Function: abc
55+
; CHECK-BB-LEVEL-NEXT: entry: [ 878.00 889.00 900.00 ]
56+
; CHECK-BB-LEVEL-NEXT: Function: abc_repeat
57+
; CHECK-BB-LEVEL-NEXT: entry: [ 878.00 889.00 900.00 ]
58+
59+
; CHECK-BB-LEVEL-ABC-REPEAT: Function: abc_repeat
60+
; CHECK-BB-LEVEL-ABC-REPEAT-NEXT: entry: [ 878.00 889.00 900.00 ]
61+
62+
; CHECK-INST-LEVEL-ABC-REPEAT: Function: abc_repeat
63+
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %a.addr = alloca i32, align 4 [ 91.00 92.00 93.00 ]
64+
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %b.addr = alloca float, align 4 [ 91.00 92.00 93.00 ]
65+
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: store i32 %a, ptr %a.addr, align 4 [ 97.00 98.00 99.00 ]
66+
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: store float %b, ptr %b.addr, align 4 [ 97.00 98.00 99.00 ]
67+
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %0 = load i32, ptr %a.addr, align 4 [ 94.00 95.00 96.00 ]
68+
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %1 = load i32, ptr %a.addr, align 4 [ 94.00 95.00 96.00 ]
69+
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %mul = mul nsw i32 %0, %1 [ 49.00 50.00 51.00 ]
70+
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %conv = sitofp i32 %mul to float [ 130.00 131.00 132.00 ]
71+
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %2 = load float, ptr %b.addr, align 4 [ 94.00 95.00 96.00 ]
72+
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: %add = fadd float %conv, %2 [ 40.00 41.00 42.00 ]
73+
; CHECK-INST-LEVEL-ABC-REPEAT-NEXT: ret float %add [ 1.00 2.00 3.00 ]

llvm/test/tools/llvm-ir2vec/triplets.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
; RUN: llvm-ir2vec %s | FileCheck %s -check-prefix=TRIPLETS
1+
; RUN: llvm-ir2vec --mode=triplets %s | FileCheck %s -check-prefix=TRIPLETS
22

33
define i32 @simple_add(i32 %a, i32 %b) {
44
entry:

llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp

Lines changed: 172 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,18 @@
99
/// \file
1010
/// This file implements the IR2Vec embedding generation tool.
1111
///
12-
/// Currently supports triplet generation for vocabulary training.
13-
/// Future updates will support embedding generation using trained vocabulary.
12+
/// This tool provides two main functionalities:
1413
///
15-
/// Usage: llvm-ir2vec input.bc -o triplets.txt
14+
/// 1. Triplet Generation Mode (--mode=triplets):
15+
/// Generates triplets (opcode, type, operands) for vocabulary training.
16+
/// Usage: llvm-ir2vec --mode=triplets input.bc -o triplets.txt
1617
///
17-
/// TODO: Add embedding generation mode with vocabulary support
18+
/// 2. Embedding Generation Mode (--mode=embeddings):
19+
/// Generates IR2Vec embeddings using a trained vocabulary.
20+
/// Usage: llvm-ir2vec --mode=embeddings --ir2vec-vocab-path=vocab.json
21+
/// --level=func input.bc -o embeddings.txt
22+
/// Levels: --level=inst (instructions), --level=bb (basic blocks), --level=func (functions)
23+
/// (See IR2Vec.cpp for more embedding generation options)
1824
///
1925
//===----------------------------------------------------------------------===//
2026

@@ -24,6 +30,8 @@
2430
#include "llvm/IR/Instructions.h"
2531
#include "llvm/IR/LLVMContext.h"
2632
#include "llvm/IR/Module.h"
33+
#include "llvm/IR/PassInstrumentation.h"
34+
#include "llvm/IR/PassManager.h"
2735
#include "llvm/IR/Type.h"
2836
#include "llvm/IRReader/IRReader.h"
2937
#include "llvm/Support/CommandLine.h"
@@ -34,7 +42,7 @@
3442
#include "llvm/Support/raw_ostream.h"
3543

3644
using namespace llvm;
37-
using namespace ir2vec;
45+
using namespace llvm::ir2vec;
3846

3947
#define DEBUG_TYPE "ir2vec"
4048

@@ -50,16 +58,63 @@ static cl::opt<std::string> OutputFilename("o", cl::desc("Output filename"),
5058
cl::init("-"),
5159
cl::cat(IR2VecToolCategory));
5260

61+
enum ToolMode {
62+
TripletMode, // Generate triplets for vocabulary training
63+
EmbeddingMode // Generate embeddings using trained vocabulary
64+
};
65+
66+
static cl::opt<ToolMode>
67+
Mode("mode", cl::desc("Tool operation mode:"),
68+
cl::values(clEnumValN(TripletMode, "triplets",
69+
"Generate triplets for vocabulary training"),
70+
clEnumValN(EmbeddingMode, "embeddings",
71+
"Generate embeddings using trained vocabulary")),
72+
cl::init(EmbeddingMode), cl::cat(IR2VecToolCategory));
73+
74+
/// --function=<name>: optional name of a single function to process.
/// Initialized to the empty string.
static cl::opt<std::string> FunctionName(
    "function", cl::desc("Process specific function only"),
    cl::value_desc("name"), cl::Optional, cl::init(""),
    cl::cat(IR2VecToolCategory));
78+
79+
enum EmbeddingLevel {
80+
InstructionLevel, // Generate instruction-level embeddings
81+
BasicBlockLevel, // Generate basic block-level embeddings
82+
FunctionLevel // Generate function-level embeddings
83+
};
84+
85+
static cl::opt<EmbeddingLevel>
86+
Level("level", cl::desc("Embedding generation level (for embedding mode):"),
87+
cl::values(clEnumValN(InstructionLevel, "inst",
88+
"Generate instruction-level embeddings"),
89+
clEnumValN(BasicBlockLevel, "bb",
90+
"Generate basic block-level embeddings"),
91+
clEnumValN(FunctionLevel, "func",
92+
"Generate function-level embeddings")),
93+
cl::init(FunctionLevel), cl::cat(IR2VecToolCategory));
94+
5395
namespace {
5496

55-
/// Helper class for collecting IR information and generating triplets
97+
/// Helper class for collecting IR triplets and generating embeddings
5698
class IR2VecTool {
5799
private:
58100
Module &M;
101+
ModuleAnalysisManager MAM;
102+
const Vocabulary *Vocab = nullptr;
59103

60104
public:
61105
explicit IR2VecTool(Module &M) : M(M) {}
62106

107+
/// Initialize the IR2Vec vocabulary analysis
108+
bool initializeVocabulary() {
109+
// Register and run the IR2Vec vocabulary analysis
110+
// The vocabulary file path is specified via --ir2vec-vocab-path global
111+
// option
112+
MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
113+
MAM.registerPass([&] { return IR2VecVocabAnalysis(); });
114+
Vocab = &MAM.getResult<IR2VecVocabAnalysis>(M);
115+
return Vocab->isValid();
116+
}
117+
63118
/// Generate triplets for the entire module
64119
void generateTriplets(raw_ostream &OS) const {
65120
for (const Function &F : M)
@@ -81,6 +136,68 @@ class IR2VecTool {
81136
OS << LocalOutput;
82137
}
83138

139+
/// Generate embeddings for the entire module
140+
void generateEmbeddings(raw_ostream &OS) const {
141+
if (!Vocab->isValid()) {
142+
OS << "Error: Vocabulary is not valid. IR2VecTool not initialized.\n";
143+
return;
144+
}
145+
146+
for (const Function &F : M)
147+
generateEmbeddings(F, OS);
148+
}
149+
150+
/// Generate embeddings for a single function
151+
void generateEmbeddings(const Function &F, raw_ostream &OS) const {
152+
if (F.isDeclaration()) {
153+
OS << "Function " << F.getName() << " is a declaration, skipping.\n";
154+
return;
155+
}
156+
157+
// Create embedder for this function
158+
assert(Vocab->isValid() && "Vocabulary is not valid");
159+
auto Emb = Embedder::create(IR2VecKind::Symbolic, F, *Vocab);
160+
if (!Emb) {
161+
OS << "Error: Failed to create embedder for function " << F.getName()
162+
<< "\n";
163+
return;
164+
}
165+
166+
OS << "Function: " << F.getName() << "\n";
167+
168+
// Generate embeddings based on the specified level
169+
switch (Level) {
170+
case FunctionLevel: {
171+
Emb->getFunctionVector().print(OS);
172+
break;
173+
}
174+
case BasicBlockLevel: {
175+
const auto &BBVecMap = Emb->getBBVecMap();
176+
for (const BasicBlock &BB : F) {
177+
auto It = BBVecMap.find(&BB);
178+
if (It != BBVecMap.end()) {
179+
OS << BB.getName() << ":";
180+
It->second.print(OS);
181+
}
182+
}
183+
break;
184+
}
185+
case InstructionLevel: {
186+
const auto &InstMap = Emb->getInstVecMap();
187+
for (const BasicBlock &BB : F) {
188+
for (const Instruction &I : BB) {
189+
auto It = InstMap.find(&I);
190+
if (It != InstMap.end()) {
191+
I.print(OS);
192+
It->second.print(OS);
193+
}
194+
}
195+
}
196+
break;
197+
}
198+
}
199+
}
200+
84201
private:
85202
/// Process a single basic block for triplet generation
86203
void traverseBasicBlock(const BasicBlock &BB, raw_string_ostream &OS) const {
@@ -105,8 +222,42 @@ class IR2VecTool {
105222

106223
/// Run the tool over \p M in the mode selected by --mode, writing to \p OS.
///
/// If --function is non-empty only that function is processed; otherwise all
/// functions in the module are.
///
/// \returns an invalid_argument error when the vocabulary cannot be
/// initialized (embedding mode only) or when the requested function does not
/// exist in the module; Error::success() otherwise.
Error processModule(Module &M, raw_ostream &OS) {
  IR2VecTool Tool(M);
  const bool WantEmbeddings = Mode == EmbeddingMode;

  // Embedding mode needs the vocabulary (--ir2vec-vocab-path); triplet mode
  // does not. This check runs before the function lookup so that a vocabulary
  // failure is reported ahead of an unknown function name.
  if (WantEmbeddings && !Tool.initializeVocabulary())
    return createStringError(
        errc::invalid_argument,
        "Failed to initialize IR2Vec vocabulary. "
        "Make sure to specify --ir2vec-vocab-path for embedding mode.");

  // No --function given: process all functions.
  if (FunctionName.empty()) {
    if (WantEmbeddings)
      Tool.generateEmbeddings(OS);
    else
      Tool.generateTriplets(OS);
    return Error::success();
  }

  // Process single function
  const Function *F = M.getFunction(FunctionName);
  if (!F)
    return createStringError(errc::invalid_argument,
                             "Function '%s' not found",
                             FunctionName.c_str());

  if (WantEmbeddings)
    Tool.generateEmbeddings(*F, OS);
  else
    Tool.generateTriplets(*F, OS);
  return Error::success();
}
112263

@@ -117,11 +268,21 @@ int main(int argc, char **argv) {
117268
cl::HideUnrelatedOptions(IR2VecToolCategory);
118269
cl::ParseCommandLineOptions(
119270
argc, argv,
120-
"IR2Vec - Triplet Generation Tool\n"
121-
"Generates triplets for vocabulary training from LLVM IR.\n"
122-
"Future updates will support embedding generation.\n\n"
271+
"IR2Vec - Embedding Generation Tool\n"
272+
"Generates embeddings for a given LLVM IR and "
273+
"supports triplet generation for vocabulary "
274+
"training and embedding generation.\n\n"
123275
"Usage:\n"
124-
" llvm-ir2vec input.bc -o triplets.txt\n");
276+
" Triplet mode: llvm-ir2vec --mode=triplets input.bc\n"
277+
" Embedding mode: llvm-ir2vec --mode=embeddings "
278+
"--ir2vec-vocab-path=vocab.json --level=func input.bc\n"
279+
" Levels: --level=inst (instructions), --level=bb (basic blocks), "
280+
"--level=func (functions)\n");
281+
282+
// Validate command line options
283+
if (Mode == TripletMode && Level != FunctionLevel) {
284+
errs() << "Warning: --level option is ignored in triplet mode\n";
285+
}
125286

126287
// Parse the input LLVM IR file
127288
SMDiagnostic Err;

0 commit comments

Comments
 (0)