Skip to content

Commit 58ab005

Browse files
svkeerthymtrofin
andauthored
Adding IR2Vec as an analysis pass (#134004)
This PR introduces IR2Vec as an analysis pass. The changes include: - Logic for generating Symbolic encodings. - 75D learned vocabulary. - lit tests. Here is the link to the RFC - https://discourse.llvm.org/t/rfc-enhancing-mlgo-inlining-with-ir2vec-embeddings Acknowledgements: contributors - https://github.com/IITH-Compilers/IR2Vec/graphs/contributors --------- Co-authored-by: svkeerthy <venkatakeerthy@google.com> Co-authored-by: Mircea Trofin <mtrofin@google.com>
1 parent 68472a3 commit 58ab005

File tree

11 files changed

+774
-0
lines changed

11 files changed

+774
-0
lines changed

llvm/docs/MLGO.rst

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,3 +347,96 @@ clang.
347347
TODO(mtrofin):
348348
- logging, and the use in interactive mode.
349349
- discuss an example (like the inliner)
350+
351+
IR2Vec Embeddings
352+
=================
353+
354+
IR2Vec is a program embedding approach designed specifically for LLVM IR. It
355+
is implemented as a function analysis pass in LLVM. The IR2Vec embeddings
356+
capture syntactic, semantic, and structural properties of the IR through
357+
learned representations. These representations are obtained as a JSON
358+
vocabulary that maps the entities of the IR (opcodes, types, operands) to
359+
n-dimensional floating point vectors (embeddings).
360+
361+
With IR2Vec, representation at different granularities of IR, such as
362+
instructions, functions, and basic blocks, can be obtained. Representations
363+
of loops and regions can be derived from these representations, which can be
364+
useful in different scenarios. The representations can be useful for various
365+
downstream tasks, including ML-guided compiler optimizations.
366+
367+
The core components are:
368+
- **Vocabulary**: A mapping from IR entities (opcodes, types, etc.) to their
369+
vector representations. This is managed by ``IR2VecVocabAnalysis``.
370+
- **Embedder**: A class (``ir2vec::Embedder``) that uses the vocabulary to
371+
compute embeddings for instructions, basic blocks, and functions.
372+
373+
Using IR2Vec
374+
------------
375+
376+
For generating embeddings, first the vocabulary should be obtained. Then, the
377+
embeddings can be computed and accessed via an ``ir2vec::Embedder`` instance.
378+
379+
1. **Get the Vocabulary**:
380+
In a ModulePass, get the vocabulary analysis result:
381+
382+
.. code-block:: c++
383+
384+
auto &VocabRes = MAM.getResult<IR2VecVocabAnalysis>(M);
385+
if (!VocabRes.isValid()) {
386+
// Handle error: vocabulary is not available or invalid
387+
return;
388+
}
389+
const ir2vec::Vocab &Vocabulary = VocabRes.getVocabulary();
390+
unsigned Dimension = VocabRes.getDimension();
391+
392+
Note that ``IR2VecVocabAnalysis`` pass is immutable.
393+
394+
2. **Create Embedder instance**:
395+
With the vocabulary, create an embedder for a specific function:
396+
397+
.. code-block:: c++
398+
399+
// Assuming F is an llvm::Function&
400+
// For example, using IR2VecKind::Symbolic:
401+
Expected<std::unique_ptr<ir2vec::Embedder>> EmbOrErr =
402+
ir2vec::Embedder::create(IR2VecKind::Symbolic, F, Vocabulary, Dimension);
403+
404+
if (auto Err = EmbOrErr.takeError()) {
405+
// Handle error in embedder creation
406+
return;
407+
}
408+
std::unique_ptr<ir2vec::Embedder> Emb = std::move(*EmbOrErr);
409+
410+
3. **Compute and Access Embeddings**:
411+
Call ``computeEmbeddings()`` on the embedder instance to compute the
412+
embeddings. Then the embeddings can be accessed using different getter
413+
methods. Currently, ``Embedder`` can generate embeddings at three levels:
414+
Instructions, Basic Blocks, and Functions.
415+
416+
.. code-block:: c++
417+
418+
Emb->computeEmbeddings();
419+
const ir2vec::Embedding &FuncVector = Emb->getFunctionVector();
420+
const ir2vec::InstEmbeddingsMap &InstVecMap = Emb->getInstVecMap();
421+
const ir2vec::BBEmbeddingsMap &BBVecMap = Emb->getBBVecMap();
422+
423+
// Example: Iterate over instruction embeddings
424+
for (const auto &Entry : InstVecMap) {
425+
const Instruction *Inst = Entry.getFirst();
426+
const ir2vec::Embedding &InstEmbedding = Entry.getSecond();
427+
// Use Inst and InstEmbedding
428+
}
429+
430+
4. **Working with Embeddings:**
431+
Embeddings are represented as ``std::vector<double>``. These
432+
vectors as features for machine learning models, compute similarity scores
433+
between different code snippets, or perform other analyses as needed.
434+
435+
Further Details
436+
---------------
437+
438+
For more detailed information about the IR2Vec algorithm, its parameters, and
439+
advanced usage, please refer to the original paper:
440+
`IR2Vec: LLVM IR Based Scalable Program Embeddings <https://doi.org/10.1145/3418463>`_.
441+
The LLVM source code for ``IR2Vec`` can also be explored to understand the
442+
implementation details.

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
//===- IR2Vec.h - Implementation of IR2Vec ----------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
4+
// Exceptions. See the LICENSE file for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
///
9+
/// \file
10+
/// This file defines the IR2Vec vocabulary analysis(IR2VecVocabAnalysis),
11+
/// the core ir2vec::Embedder interface for generating IR embeddings,
12+
/// and related utilities like the IR2VecPrinterPass.
13+
///
14+
/// Program Embeddings are typically or derived-from a learned
15+
/// representation of the program. Such embeddings are used to represent the
16+
/// programs as input to machine learning algorithms. IR2Vec represents the
17+
/// LLVM IR as embeddings.
18+
///
19+
/// The IR2Vec algorithm is described in the following paper:
20+
///
21+
/// IR2Vec: LLVM IR Based Scalable Program Embeddings, S. VenkataKeerthy,
22+
/// Rohit Aggarwal, Shalini Jain, Maunendra Sankar Desarkar, Ramakrishna
23+
/// Upadrasta, and Y. N. Srikant, ACM Transactions on Architecture and
24+
/// Code Optimization (TACO), 2020. https://doi.org/10.1145/3418463.
25+
/// https://arxiv.org/abs/1909.06228
26+
///
27+
//===----------------------------------------------------------------------===//
28+
29+
#ifndef LLVM_ANALYSIS_IR2VEC_H
30+
#define LLVM_ANALYSIS_IR2VEC_H
31+
32+
#include "llvm/ADT/DenseMap.h"
33+
#include "llvm/IR/PassManager.h"
34+
#include "llvm/Support/ErrorOr.h"
35+
#include <map>
36+
37+
namespace llvm {
38+
39+
class Module;
40+
class BasicBlock;
41+
class Instruction;
42+
class Function;
43+
class Type;
44+
class Value;
45+
class raw_ostream;
46+
47+
/// IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
48+
/// Symbolic embeddings capture the "syntactic" and "statistical correlation"
49+
/// of the IR entities. Flow-aware embeddings build on top of symbolic
50+
/// embeddings and additionally capture the flow information in the IR.
51+
/// IR2VecKind is used to specify the type of embeddings to generate.
52+
/// Currently, only Symbolic embeddings are supported.
53+
enum class IR2VecKind { Symbolic };
54+
55+
namespace ir2vec {
56+
using Embedding = std::vector<double>;
57+
using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
58+
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
59+
// FIXME: Current the keys are strings. This can be changed to
60+
// use integers for cheaper lookups.
61+
using Vocab = std::map<std::string, Embedding>;
62+
63+
/// Embedder provides the interface to generate embeddings (vector
64+
/// representations) for instructions, basic blocks, and functions. The vector
65+
/// representations are generated using IR2Vec algorithms.
66+
///
67+
/// The Embedder class is an abstract class and it is intended to be
68+
/// subclassed for different IR2Vec algorithms like Symbolic and Flow-aware.
69+
class Embedder {
70+
protected:
71+
const Function &F;
72+
const Vocab &Vocabulary;
73+
74+
/// Dimension of the vector representation; captured from the input vocabulary
75+
const unsigned Dimension;
76+
77+
/// Weights for different entities (like opcode, arguments, types)
78+
/// in the IR instructions to generate the vector representation.
79+
const float OpcWeight, TypeWeight, ArgWeight;
80+
81+
// Utility maps - these are used to store the vector representations of
82+
// instructions, basic blocks and functions.
83+
Embedding FuncVector;
84+
BBEmbeddingsMap BBVecMap;
85+
InstEmbeddingsMap InstVecMap;
86+
87+
Embedder(const Function &F, const Vocab &Vocabulary, unsigned Dimension);
88+
89+
/// Lookup vocabulary for a given Key. If the key is not found, it returns a
90+
/// zero vector.
91+
Embedding lookupVocab(const std::string &Key) const;
92+
93+
/// Adds two vectors: Dst += Src
94+
static void addVectors(Embedding &Dst, const Embedding &Src);
95+
96+
/// Adds Src vector scaled by Factor to Dst vector: Dst += Src * Factor
97+
static void addScaledVector(Embedding &Dst, const Embedding &Src,
98+
float Factor);
99+
100+
public:
101+
virtual ~Embedder() = default;
102+
103+
/// Top level function to compute embeddings. It generates embeddings for all
104+
/// the instructions and basic blocks in the function F. Logic of computing
105+
/// the embeddings is specific to the kind of embeddings being computed.
106+
virtual void computeEmbeddings() = 0;
107+
108+
/// Factory method to create an Embedder object.
109+
static Expected<std::unique_ptr<Embedder>> create(IR2VecKind Mode,
110+
const Function &F,
111+
const Vocab &Vocabulary,
112+
unsigned Dimension);
113+
114+
/// Returns a map containing instructions and the corresponding vector
115+
/// representations for a given module corresponding to the IR2Vec
116+
/// algorithm.
117+
const InstEmbeddingsMap &getInstVecMap() const { return InstVecMap; }
118+
119+
/// Returns a map containing basic block and the corresponding vector
120+
/// representations for a given module corresponding to the IR2Vec
121+
/// algorithm.
122+
const BBEmbeddingsMap &getBBVecMap() const { return BBVecMap; }
123+
124+
/// Returns the vector representation for a given function corresponding to
125+
/// the IR2Vec algorithm.
126+
const Embedding &getFunctionVector() const { return FuncVector; }
127+
};
128+
129+
/// Class for computing the Symbolic embeddings of IR2Vec.
130+
/// Symbolic embeddings are constructed based on the entity-level
131+
/// representations obtained from the Vocabulary.
132+
class SymbolicEmbedder : public Embedder {
133+
private:
134+
/// Utility function to compute the vector representation for a given basic
135+
/// block.
136+
Embedding computeBB2Vec(const BasicBlock &BB);
137+
138+
/// Utility function to compute the vector representation for a given type.
139+
Embedding getTypeEmbedding(const Type *Ty) const;
140+
141+
/// Utility function to compute the vector representation for a given
142+
/// operand.
143+
Embedding getOperandEmbedding(const Value *Op) const;
144+
145+
public:
146+
SymbolicEmbedder(const Function &F, const Vocab &Vocabulary,
147+
unsigned Dimension)
148+
: Embedder(F, Vocabulary, Dimension) {
149+
FuncVector = Embedding(Dimension, 0);
150+
}
151+
void computeEmbeddings() override;
152+
};
153+
154+
} // namespace ir2vec
155+
156+
/// Class for storing the result of the IR2VecVocabAnalysis.
157+
class IR2VecVocabResult {
158+
ir2vec::Vocab Vocabulary;
159+
bool Valid = false;
160+
161+
public:
162+
IR2VecVocabResult() = default;
163+
IR2VecVocabResult(ir2vec::Vocab &&Vocabulary);
164+
165+
bool isValid() const { return Valid; }
166+
const ir2vec::Vocab &getVocabulary() const;
167+
unsigned getDimension() const;
168+
bool invalidate(Module &M, const PreservedAnalyses &PA,
169+
ModuleAnalysisManager::Invalidator &Inv) const;
170+
};
171+
172+
/// This analysis provides the vocabulary for IR2Vec. The vocabulary provides a
173+
/// mapping between an entity of the IR (like opcode, type, argument, etc.) and
174+
/// its corresponding embedding.
175+
class IR2VecVocabAnalysis : public AnalysisInfoMixin<IR2VecVocabAnalysis> {
176+
ir2vec::Vocab Vocabulary;
177+
Error readVocabulary();
178+
179+
public:
180+
static AnalysisKey Key;
181+
IR2VecVocabAnalysis() = default;
182+
using Result = IR2VecVocabResult;
183+
Result run(Module &M, ModuleAnalysisManager &MAM);
184+
};
185+
186+
/// This pass prints the IR2Vec embeddings for instructions, basic blocks, and
187+
/// functions.
188+
class IR2VecPrinterPass : public PassInfoMixin<IR2VecPrinterPass> {
189+
raw_ostream &OS;
190+
void printVector(const ir2vec::Embedding &Vec) const;
191+
192+
public:
193+
explicit IR2VecPrinterPass(raw_ostream &OS) : OS(OS) {}
194+
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
195+
static bool isRequired() { return true; }
196+
};
197+
198+
} // namespace llvm
199+
200+
#endif // LLVM_ANALYSIS_IR2VEC_H

llvm/lib/Analysis/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ add_llvm_component_library(LLVMAnalysis
7979
GlobalsModRef.cpp
8080
GuardUtils.cpp
8181
HeatUtils.cpp
82+
IR2Vec.cpp
8283
IRSimilarityIdentifier.cpp
8384
IVDescriptors.cpp
8485
IVUsers.cpp

0 commit comments

Comments
 (0)