diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h index f5a4e450cf160..176cdaf7b5378 100644 --- a/llvm/include/llvm/Analysis/IR2Vec.h +++ b/llvm/include/llvm/Analysis/IR2Vec.h @@ -163,15 +163,18 @@ class Vocabulary { static constexpr unsigned MaxOperandKinds = static_cast(OperandKind::MaxOperandKind); + /// Helper function to get vocabulary key for a given Opcode + static StringRef getVocabKeyForOpcode(unsigned Opcode); + + /// Helper function to get vocabulary key for a given TypeID + static StringRef getVocabKeyForTypeID(Type::TypeID TypeID); + /// Helper function to get vocabulary key for a given OperandKind static StringRef getVocabKeyForOperandKind(OperandKind Kind); /// Helper function to classify an operand into OperandKind static OperandKind getOperandKind(const Value *Op); - /// Helper function to get vocabulary key for a given TypeID - static StringRef getVocabKeyForTypeID(Type::TypeID TypeID); - public: Vocabulary() = default; Vocabulary(VocabVector &&Vocab); diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index d3dc2e36fd23e..f97644b93a3d4 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -243,6 +243,17 @@ const ir2vec::Embedding &Vocabulary::operator[](const Value *Arg) const { return Vocab[MaxOpcodes + MaxTypeIDs + static_cast(ArgKind)]; } +StringRef Vocabulary::getVocabKeyForOpcode(unsigned Opcode) { + assert(Opcode >= 1 && Opcode <= MaxOpcodes && "Invalid opcode"); +#define HANDLE_INST(NUM, OPCODE, CLASS) \ + if (Opcode == NUM) { \ + return #OPCODE; \ + } +#include "llvm/IR/Instruction.def" +#undef HANDLE_INST + return "UnknownOpcode"; +} + StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) { switch (TypeID) { case Type::VoidTyID: @@ -280,6 +291,7 @@ StringRef Vocabulary::getVocabKeyForTypeID(Type::TypeID TypeID) { default: return "UnknownTy"; } + return "UnknownTy"; } // Operand kinds supported by IR2Vec - string mappings @@ -297,9 +309,9 @@ StringRef Vocabulary::getVocabKeyForOperandKind(Vocabulary::OperandKind Kind) { OPERAND_KINDS #undef OPERAND_KIND case Vocabulary::OperandKind::MaxOperandKind: - llvm_unreachable("Invalid OperandKind"); + return "UnknownOperand"; } - llvm_unreachable("Unknown OperandKind"); + return "UnknownOperand"; } #undef OPERAND_KINDS @@ -332,14 +344,8 @@ StringRef Vocabulary::getStringKey(unsigned Pos) { assert(Pos < MaxOpcodes + MaxTypeIDs + MaxOperandKinds && "Position out of bounds in vocabulary"); // Opcode - if (Pos < MaxOpcodes) { -#define HANDLE_INST(NUM, OPCODE, CLASS) \ - if (Pos == NUM - 1) { \ - return #OPCODE; \ - } -#include "llvm/IR/Instruction.def" -#undef HANDLE_INST - } + if (Pos < MaxOpcodes) + return getVocabKeyForOpcode(Pos + 1); // Type if (Pos < MaxOpcodes + MaxTypeIDs) return getVocabKeyForTypeID(static_cast(Pos - MaxOpcodes)); @@ -447,21 +453,18 @@ void IR2VecVocabAnalysis::generateNumMappedVocab() { // Handle Opcodes std::vector NumericOpcodeEmbeddings(Vocabulary::MaxOpcodes, Embedding(Dim, 0)); -#define HANDLE_INST(NUM, OPCODE, CLASS) \ - { \ - auto It = OpcVocab.find(#OPCODE); \ - if (It != OpcVocab.end()) \ - NumericOpcodeEmbeddings[NUM - 1] = It->second; \ - else \ - handleMissingEntity(#OPCODE); \ + for (unsigned Opcode : seq(0u, Vocabulary::MaxOpcodes)) { + StringRef VocabKey = Vocabulary::getVocabKeyForOpcode(Opcode + 1); + auto It = OpcVocab.find(VocabKey.str()); + if (It != OpcVocab.end()) + NumericOpcodeEmbeddings[Opcode] = It->second; + else + handleMissingEntity(VocabKey.str()); } -#include "llvm/IR/Instruction.def" -#undef HANDLE_INST Vocab.insert(Vocab.end(), NumericOpcodeEmbeddings.begin(), NumericOpcodeEmbeddings.end()); - // Handle Types using direct iteration through TypeID enum - // We iterate through all possible TypeID values and map them to embeddings + // Handle Types std::vector NumericTypeEmbeddings(Vocabulary::MaxTypeIDs, Embedding(Dim, 0)); for (unsigned TypeID : seq(0u, Vocabulary::MaxTypeIDs)) {