31
31
32
32
#include " llvm/ADT/DenseMap.h"
33
33
#include " llvm/IR/PassManager.h"
34
+ #include " llvm/IR/Type.h"
34
35
#include " llvm/Support/CommandLine.h"
35
36
#include " llvm/Support/Compiler.h"
36
37
#include " llvm/Support/ErrorOr.h"
@@ -43,10 +44,10 @@ class Module;
43
44
class BasicBlock ;
44
45
class Instruction ;
45
46
class Function ;
46
- class Type ;
47
47
class Value ;
48
48
class raw_ostream ;
49
49
class LLVMContext ;
50
+ class IR2VecVocabAnalysis ;
50
51
51
52
// / IR2Vec computes two kinds of embeddings: Symbolic and Flow-aware.
52
53
// / Symbolic embeddings capture the "syntactic" and "statistical correlation"
@@ -128,9 +129,95 @@ struct Embedding {
128
129
129
130
using InstEmbeddingsMap = DenseMap<const Instruction *, Embedding>;
130
131
using BBEmbeddingsMap = DenseMap<const BasicBlock *, Embedding>;
131
- // FIXME: Current the keys are strings. This can be changed to
132
- // use integers for cheaper lookups.
133
- using Vocab = std::map<std::string, Embedding>;
132
+
133
+ // / Class for storing and accessing the IR2Vec vocabulary.
134
+ // / Encapsulates all vocabulary-related constants, logic, and access methods.
135
+ class Vocabulary {
136
+ friend class llvm ::IR2VecVocabAnalysis;
137
+ using VocabVector = std::vector<ir2vec::Embedding>;
138
+ VocabVector Vocab;
139
+ bool Valid = false ;
140
+
141
+ // / Operand kinds supported by IR2Vec Vocabulary
142
+ #define OPERAND_KINDS \
143
+ OPERAND_KIND (FunctionID, " Function" ) \
144
+ OPERAND_KIND (PointerID, " Pointer" ) \
145
+ OPERAND_KIND (ConstantID, " Constant" ) \
146
+ OPERAND_KIND (VariableID, " Variable" )
147
+
148
+ enum class OperandKind : unsigned {
149
+ #define OPERAND_KIND (Name, Str ) Name,
150
+ OPERAND_KINDS
151
+ #undef OPERAND_KIND
152
+ MaxOperandKind
153
+ };
154
+
155
+ #undef OPERAND_KINDS
156
+
157
+ // / Vocabulary layout constants
158
+ #define LAST_OTHER_INST (NUM ) static constexpr unsigned MaxOpcodes = NUM;
159
+ #include " llvm/IR/Instruction.def"
160
+ #undef LAST_OTHER_INST
161
+
162
+ static constexpr unsigned MaxTypeIDs = Type::TypeID::TargetExtTyID + 1 ;
163
+ static constexpr unsigned MaxOperandKinds =
164
+ static_cast <unsigned >(OperandKind::MaxOperandKind);
165
+
166
+ // / Helper function to get vocabulary key for a given OperandKind
167
+ static StringRef getVocabKeyForOperandKind (OperandKind Kind);
168
+
169
+ // / Helper function to classify an operand into OperandKind
170
+ static OperandKind getOperandKind (const Value *Op);
171
+
172
+ // / Helper function to get vocabulary key for a given TypeID
173
+ static StringRef getVocabKeyForTypeID (Type::TypeID TypeID);
174
+
175
+ public:
176
+ Vocabulary () = default ;
177
+ Vocabulary (VocabVector &&Vocab);
178
+
179
+ bool isValid () const ;
180
+ unsigned getDimension () const ;
181
+ size_t size () const ;
182
+
183
+ // / Accessors to get the embedding for a given entity.
184
+ const ir2vec::Embedding &operator [](unsigned Opcode) const ;
185
+ const ir2vec::Embedding &operator [](Type::TypeID TypeId) const ;
186
+ const ir2vec::Embedding &operator [](const Value *Arg) const ;
187
+
188
+ // / Const Iterator type aliases
189
+ using const_iterator = VocabVector::const_iterator;
190
+ const_iterator begin () const {
191
+ assert (Valid && " IR2Vec Vocabulary is invalid" );
192
+ return Vocab.begin ();
193
+ }
194
+
195
+ const_iterator cbegin () const {
196
+ assert (Valid && " IR2Vec Vocabulary is invalid" );
197
+ return Vocab.cbegin ();
198
+ }
199
+
200
+ const_iterator end () const {
201
+ assert (Valid && " IR2Vec Vocabulary is invalid" );
202
+ return Vocab.end ();
203
+ }
204
+
205
+ const_iterator cend () const {
206
+ assert (Valid && " IR2Vec Vocabulary is invalid" );
207
+ return Vocab.cend ();
208
+ }
209
+
210
+ // / Returns the string key for a given index position in the vocabulary.
211
+ // / This is useful for debugging or printing the vocabulary. Do not use this
212
+ // / for embedding generation as string based lookups are inefficient.
213
+ static StringRef getStringKey (unsigned Pos);
214
+
215
+ // / Create a dummy vocabulary for testing purposes.
216
+ static VocabVector createDummyVocabForTest (unsigned Dim = 1 );
217
+
218
+ bool invalidate (Module &M, const PreservedAnalyses &PA,
219
+ ModuleAnalysisManager::Invalidator &Inv) const ;
220
+ };
134
221
135
222
// / Embedder provides the interface to generate embeddings (vector
136
223
// / representations) for instructions, basic blocks, and functions. The
@@ -141,7 +228,7 @@ using Vocab = std::map<std::string, Embedding>;
141
228
class Embedder {
142
229
protected:
143
230
const Function &F;
144
- const Vocab &Vocabulary ;
231
+ const Vocabulary &Vocab ;
145
232
146
233
// / Dimension of the vector representation; captured from the input vocabulary
147
234
const unsigned Dimension;
@@ -156,7 +243,7 @@ class Embedder {
156
243
mutable BBEmbeddingsMap BBVecMap;
157
244
mutable InstEmbeddingsMap InstVecMap;
158
245
159
- LLVM_ABI Embedder (const Function &F, const Vocab &Vocabulary );
246
+ LLVM_ABI Embedder (const Function &F, const Vocabulary &Vocab );
160
247
161
248
// / Helper function to compute embeddings. It generates embeddings for all
162
249
// / the instructions and basic blocks in the function F. Logic of computing
@@ -167,16 +254,12 @@ class Embedder {
167
254
// / Specific to the kind of embeddings being computed.
168
255
virtual void computeEmbeddings (const BasicBlock &BB) const = 0;
169
256
170
- // / Lookup vocabulary for a given Key. If the key is not found, it returns a
171
- // / zero vector.
172
- LLVM_ABI Embedding lookupVocab (const std::string &Key) const ;
173
-
174
257
public:
175
258
virtual ~Embedder () = default ;
176
259
177
260
// / Factory method to create an Embedder object.
178
261
LLVM_ABI static std::unique_ptr<Embedder>
179
- create (IR2VecKind Mode, const Function &F, const Vocab &Vocabulary );
262
+ create (IR2VecKind Mode, const Function &F, const Vocabulary &Vocab );
180
263
181
264
// / Returns a map containing instructions and the corresponding embeddings for
182
265
// / the function F if it has been computed. If not, it computes the embeddings
@@ -202,56 +285,39 @@ class Embedder {
202
285
// / representations obtained from the Vocabulary.
203
286
class LLVM_ABI SymbolicEmbedder : public Embedder {
204
287
private:
205
- // / Utility function to compute the embedding for a given type.
206
- Embedding getTypeEmbedding (const Type *Ty) const ;
207
-
208
- // / Utility function to compute the embedding for a given operand.
209
- Embedding getOperandEmbedding (const Value *Op) const ;
210
-
211
288
void computeEmbeddings () const override ;
212
289
void computeEmbeddings (const BasicBlock &BB) const override ;
213
290
214
291
public:
215
- SymbolicEmbedder (const Function &F, const Vocab &Vocabulary )
216
- : Embedder(F, Vocabulary ) {
292
+ SymbolicEmbedder (const Function &F, const Vocabulary &Vocab )
293
+ : Embedder(F, Vocab ) {
217
294
FuncVector = Embedding (Dimension, 0 );
218
295
}
219
296
};
220
297
221
298
} // namespace ir2vec
222
299
223
- // / Class for storing the result of the IR2VecVocabAnalysis.
224
- class IR2VecVocabResult {
225
- ir2vec::Vocab Vocabulary;
226
- bool Valid = false ;
227
-
228
- public:
229
- IR2VecVocabResult () = default ;
230
- LLVM_ABI IR2VecVocabResult (ir2vec::Vocab &&Vocabulary);
231
-
232
- bool isValid () const { return Valid; }
233
- LLVM_ABI const ir2vec::Vocab &getVocabulary () const ;
234
- LLVM_ABI unsigned getDimension () const ;
235
- LLVM_ABI bool invalidate (Module &M, const PreservedAnalyses &PA,
236
- ModuleAnalysisManager::Invalidator &Inv) const ;
237
- };
238
-
239
300
// / This analysis provides the vocabulary for IR2Vec. The vocabulary provides a
240
301
// / mapping between an entity of the IR (like opcode, type, argument, etc.) and
241
302
// / its corresponding embedding.
242
303
class IR2VecVocabAnalysis : public AnalysisInfoMixin <IR2VecVocabAnalysis> {
243
- ir2vec::Vocab Vocabulary;
304
+ using VocabVector = std::vector<ir2vec::Embedding>;
305
+ using VocabMap = std::map<std::string, ir2vec::Embedding>;
306
+ VocabMap OpcVocab, TypeVocab, ArgVocab;
307
+ VocabVector Vocab;
308
+
244
309
Error readVocabulary ();
245
310
Error parseVocabSection (StringRef Key, const json::Value &ParsedVocabValue,
246
- ir2vec::Vocab &TargetVocab, unsigned &Dim);
311
+ VocabMap &TargetVocab, unsigned &Dim);
312
+ void generateNumMappedVocab ();
247
313
void emitError (Error Err, LLVMContext &Ctx);
248
314
249
315
public:
250
316
LLVM_ABI static AnalysisKey Key;
251
317
IR2VecVocabAnalysis () = default ;
252
- LLVM_ABI explicit IR2VecVocabAnalysis (const ir2vec::Vocab &Vocab);
253
- LLVM_ABI explicit IR2VecVocabAnalysis (ir2vec::Vocab &&Vocab);
254
- using Result = IR2VecVocabResult ;
318
+ LLVM_ABI explicit IR2VecVocabAnalysis (const VocabVector &Vocab);
319
+ LLVM_ABI explicit IR2VecVocabAnalysis (VocabVector &&Vocab);
320
+ using Result = ir2vec::Vocabulary ;
255
321
LLVM_ABI Result run (Module &M, ModuleAnalysisManager &MAM);
256
322
};
257
323
0 commit comments