31
31
32
32
#include " llvm/ADT/DenseMap.h"
33
33
#include " llvm/IR/PassManager.h"
34
+ #include " llvm/IR/Type.h"
34
35
#include " llvm/Support/CommandLine.h"
35
36
#include " llvm/Support/Compiler.h"
36
37
#include " llvm/Support/ErrorOr.h"
@@ -43,10 +44,10 @@ class Module;
43
44
class BasicBlock ;
44
45
class Instruction ;
45
46
class Function ;
46
- class Type ;
47
47
class Value ;
48
48
class raw_ostream ;
49
49
class LLVMContext ;
50
+ class IR2VecVocabAnalysis ;
50
51
51
52
/// IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
52
53
/// Symbolic embeddings capture the "syntactic" and "statistical correlation"
@@ -128,9 +129,73 @@ struct Embedding {
128
129
129
130
using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
130
131
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
131
- // FIXME: Current the keys are strings. This can be changed to
132
- // use integers for cheaper lookups.
133
- using Vocab = std::map<std::string, Embedding>;
132
+
133
+ /// Class for storing and accessing the IR2Vec vocabulary.
134
+ /// Encapsulates all vocabulary-related constants, logic, and access methods.
// Embeddings are held in a std::vector (`Vocab`) and addressed by unsigned
// position via at(); the operator[] overloads below accept an opcode, a
// Type::TypeID, or a Value pointer as the key. Only declarations appear in
// this header, so how each key is turned into a position is defined out of
// line.
135
+ class Vocabulary {
136
// IR2VecVocabAnalysis is a friend, so it can reach the private members below.
+ friend class llvm ::IR2VecVocabAnalysis;
137
+ using VocabVector = std::vector<ir2vec::Embedding>;
138
+ VocabVector Vocab;
139
// Starts out false, so a default-constructed Vocabulary begins as not valid.
// NOTE(review): presumably this is what isValid() reports — definition not
// shown here.
+ bool Valid = false ;
140
+
141
+ /// Operand kinds supported by IR2Vec Vocabulary
142
// List macro: each entry is OPERAND_KIND(<enumerator name>, <string>). The
// caller defines OPERAND_KIND before expanding OPERAND_KINDS.
+ #define OPERAND_KINDS \
143
+ OPERAND_KIND (FunctionID, " Function" ) \
144
+ OPERAND_KIND (PointerID, " Pointer" ) \
145
+ OPERAND_KIND (ConstantID, " Constant" ) \
146
+ OPERAND_KIND (VariableID, " Variable" )
147
+
148
// The enumerators are generated from OPERAND_KINDS: OPERAND_KIND is defined
// to emit only `Name,` (its Str argument is ignored in this expansion).
// MaxOperandKind is listed after the generated enumerators, so its value
// equals the number of operand kinds in the list.
+ enum class OperandKind : unsigned {
149
+ #define OPERAND_KIND (Name, Str ) Name,
150
+ OPERAND_KINDS
151
+ #undef OPERAND_KIND
152
+ MaxOperandKind
153
+ };
154
+
155
// The list macro is removed here, so it is not visible past this point in
// the header.
+ #undef OPERAND_KINDS
156
+
157
+ /// Vocabulary layout constants
158
// MaxOpcodes takes whatever NUM Instruction.def passes to LAST_OTHER_INST.
// NOTE(review): presumably NUM is the number of the last instruction opcode —
// confirm against Instruction.def.
+ #define LAST_OTHER_INST (NUM ) static constexpr unsigned MaxOpcodes = NUM;
159
+ #include " llvm/IR/Instruction.def"
160
+ #undef LAST_OTHER_INST
161
+
162
// One more than TargetExtTyID. NOTE(review): this assumes TargetExtTyID is
// the largest Type::TypeID enumerator — TODO confirm against llvm/IR/Type.h.
+ static constexpr unsigned MaxTypes = Type::TypeID::TargetExtTyID + 1 ;
163
// Numeric value of the MaxOperandKind sentinel declared above.
+ static constexpr unsigned MaxOperandKinds =
164
+ static_cast <unsigned >(OperandKind::MaxOperandKind);
165
+
166
+ /// Helper function to get vocabulary key for a given OperandKind
167
+ static StringRef getVocabKeyForOperandKind (OperandKind Kind);
168
+
169
+ /// Helper function to classify an operand into OperandKind
170
+ static OperandKind getOperandKind (const Value *Op);
171
+
172
+ /// Helper function to get vocabulary key for a given TypeID
173
+ static StringRef getVocabKeyForTypeID (Type::TypeID TypeID);
174
+
175
+ public:
176
+ Vocabulary () = default ;
177
// Takes the embedding vector by rvalue reference; this constructor is not
// marked `explicit`. NOTE(review): presumably it moves the vector into
// `Vocab` and marks the object valid — definition not shown here.
+ Vocabulary (VocabVector &&Vocab);
178
+
179
+ bool isValid () const ;
180
+ unsigned getDimension () const ;
181
+ unsigned size () const ;
182
+
183
// Element access. at() is keyed by position, while the operator[] overloads
// are keyed (per their parameter names/types) by opcode, type ID, and Value
// pointer. Note that at() and the first operator[] both take `unsigned`.
+ const ir2vec::Embedding &at (unsigned Position) const ;
184
+ const ir2vec::Embedding &operator [](unsigned Opcode) const ;
185
+ const ir2vec::Embedding &operator [](Type::TypeID TypeId) const ;
186
+ const ir2vec::Embedding &operator [](const Value *Arg) const ;
187
+
188
+ /// Returns the string key for a given index position in the vocabulary.
189
+ /// This is useful for debugging or printing the vocabulary. Do not use this
190
+ /// for embedding generation as string based lookups are inefficient.
191
+ static StringRef getStringKey (unsigned Pos);
192
+
193
+ /// Create a dummy vocabulary for testing purposes.
194
// Dim defaults to 1. NOTE(review): presumably Dim is the dimension of each
// generated embedding — confirm in the definition.
+ static VocabVector createDummyVocabForTest (unsigned Dim = 1 );
195
+
196
// NOTE(review): looks like the analysis-result invalidation hook queried by
// the new pass manager — confirm against the PassManager documentation.
+ bool invalidate (Module &M, const PreservedAnalyses &PA,
197
+ ModuleAnalysisManager::Invalidator &Inv) const ;
198
+ };
134
199
135
200
/// Embedder provides the interface to generate embeddings (vector
136
201
/// representations) for instructions, basic blocks, and functions. The
@@ -141,7 +206,7 @@ using Vocab = std::map<std::string, Embedding>;
141
206
class Embedder {
142
207
protected:
143
208
const Function &F;
144
- const Vocab &Vocabulary ;
209
+ const Vocabulary &Vocab ;
145
210
146
211
/// Dimension of the vector representation; captured from the input vocabulary
147
212
const unsigned Dimension;
@@ -156,7 +221,7 @@ class Embedder {
156
221
mutable BBEmbeddingsMap BBVecMap;
157
222
mutable InstEmbeddingsMap InstVecMap;
158
223
159
- LLVM_ABI Embedder (const Function &F, const Vocab &Vocabulary );
224
+ LLVM_ABI Embedder (const Function &F, const Vocabulary &Vocab );
160
225
161
226
/// Helper function to compute embeddings. It generates embeddings for all
162
227
/// the instructions and basic blocks in the function F. Logic of computing
@@ -167,16 +232,12 @@ class Embedder {
167
232
/// Specific to the kind of embeddings being computed.
168
233
virtual void computeEmbeddings (const BasicBlock &BB) const = 0;
169
234
170
- // / Lookup vocabulary for a given Key. If the key is not found, it returns a
171
- // / zero vector.
172
- LLVM_ABI Embedding lookupVocab (const std::string &Key) const ;
173
-
174
235
public:
175
236
virtual ~Embedder () = default ;
176
237
177
238
/// Factory method to create an Embedder object.
178
239
LLVM_ABI static std::unique_ptr<Embedder>
179
- create (IR2VecKind Mode, const Function &F, const Vocab &Vocabulary );
240
+ create (IR2VecKind Mode, const Function &F, const Vocabulary &Vocab );
180
241
181
242
/// Returns a map containing instructions and the corresponding embeddings for
182
243
/// the function F if it has been computed. If not, it computes the embeddings
@@ -202,56 +263,40 @@ class Embedder {
202
263
/// representations obtained from the Vocabulary.
203
264
// Concrete Embedder subclass. In this diff view, lines marked `-` are removed
// by the change and lines marked `+` are added; unmarked lines are unchanged.
class LLVM_ABI SymbolicEmbedder : public Embedder {
204
265
private:
205
- /// Utility function to compute the embedding for a given type.
206
- Embedding getTypeEmbedding (const Type *Ty) const ;
207
-
208
- /// Utility function to compute the embedding for a given operand.
209
- Embedding getOperandEmbedding (const Value *Op) const ;
210
-
211
266
// Overrides of the embedding-computation hooks; both are declared here and
// defined out of line.
void computeEmbeddings () const override ;
212
267
void computeEmbeddings (const BasicBlock &BB) const override ;
213
268
214
269
public:
215
// Old constructor signature (removed): took the vocabulary as `const Vocab &`.
- SymbolicEmbedder (const Function &F, const Vocab &Vocabulary )
216
- : Embedder(F, Vocabulary ) {
270
// New constructor signature (added): takes `const Vocabulary &` and forwards
// both arguments to the Embedder base constructor.
+ SymbolicEmbedder (const Function &F, const Vocabulary &Vocab )
271
+ : Embedder(F, Vocab ) {
217
272
// Resets FuncVector to an Embedding constructed from (Dimension, 0).
// NOTE(review): FuncVector is not declared in the visible part of this
// file; presumably an Embedder member, and presumably this yields a
// zero-filled vector of length Dimension — confirm against Embedding.
FuncVector = Embedding (Dimension, 0 );
218
273
}
219
274
};
220
275
221
276
} // namespace ir2vec
222
277
223
- /// Class for storing the result of the IR2VecVocabAnalysis.
// Every line of this class carries the `-` diff marker: the change shown in
// this file removes it, and IR2VecVocabAnalysis::Result is switched from
// IR2VecVocabResult to ir2vec::Vocabulary.
224
- class IR2VecVocabResult {
225
- ir2vec::Vocab Vocabulary;
226
// Defaults to false; isValid() below returns this flag directly.
- bool Valid = false ;
227
-
228
- public:
229
- IR2VecVocabResult () = default ;
230
- LLVM_ABI IR2VecVocabResult (ir2vec::Vocab &&Vocabulary);
231
-
232
- bool isValid () const { return Valid; }
233
- LLVM_ABI const ir2vec::Vocab &getVocabulary () const ;
234
- LLVM_ABI unsigned getDimension () const ;
235
- LLVM_ABI bool invalidate (Module &M, const PreservedAnalyses &PA,
236
- ModuleAnalysisManager::Invalidator &Inv) const ;
237
- };
238
-
239
278
/// This analysis provides the vocabulary for IR2Vec. The vocabulary provides a
240
279
/// mapping between an entity of the IR (like opcode, type, argument, etc.) and
241
280
/// its corresponding embedding.
// In this diff view, lines marked `-` are removed by the change and lines
// marked `+` are added; unmarked lines are unchanged.
242
281
class IR2VecVocabAnalysis : public AnalysisInfoMixin <IR2VecVocabAnalysis> {
243
- ir2vec::Vocab Vocabulary;
282
+ using VocabVector = std::vector<ir2vec::Embedding>;
283
+ using VocabMap = std::map<std::string, ir2vec::Embedding>;
284
// Three separate string-keyed maps plus one vector of embeddings.
// NOTE(review): presumably the maps hold the opcode, type and argument
// sections as read from the vocabulary input, and generateNumMappedVocab()
// fills `Vocab` from them — confirm in the implementation file.
+ VocabMap OpcVocab, TypeVocab, ArgVocab;
285
+ VocabVector Vocab;
286
+
287
+ unsigned Dim = 0 ;
244
288
Error readVocabulary ();
245
289
// NOTE(review): the `unsigned &Dim` parameter below has the same name as the
// `Dim` data member added above, so inside the definition the parameter
// hides the member.
Error parseVocabSection (StringRef Key, const json::Value &ParsedVocabValue,
246
- ir2vec::Vocab &TargetVocab, unsigned &Dim);
290
+ VocabMap &TargetVocab, unsigned &Dim);
291
+ void generateNumMappedVocab ();
247
292
void emitError (Error Err, LLVMContext &Ctx);
248
293
249
294
public:
250
295
LLVM_ABI static AnalysisKey Key;
251
296
IR2VecVocabAnalysis () = default ;
252
// The explicit constructors change their parameter type from ir2vec::Vocab
// (the removed string-keyed map alias) to VocabVector; a copying
// (const &) and a moving (&&) overload are kept.
- LLVM_ABI explicit IR2VecVocabAnalysis (const ir2vec::Vocab &Vocab);
253
- LLVM_ABI explicit IR2VecVocabAnalysis (ir2vec::Vocab &&Vocab);
254
- using Result = IR2VecVocabResult ;
297
+ LLVM_ABI explicit IR2VecVocabAnalysis (const VocabVector &Vocab);
298
+ LLVM_ABI explicit IR2VecVocabAnalysis (VocabVector &&Vocab);
299
// The analysis result type is now ir2vec::Vocabulary; run() returns it by
// value.
+ using Result = ir2vec::Vocabulary ;
255
300
LLVM_ABI Result run (Module &M, ModuleAnalysisManager &MAM);
256
301
};
257
302
0 commit comments