diff --git a/llvm/test/tools/llvm-ir2vec/triplets.ll b/llvm/test/tools/llvm-ir2vec/triplets.ll new file mode 100644 index 0000000000000..fa5aaa895406f --- /dev/null +++ b/llvm/test/tools/llvm-ir2vec/triplets.ll @@ -0,0 +1,38 @@ +; RUN: llvm-ir2vec %s | FileCheck %s -check-prefix=TRIPLETS + +define i32 @simple_add(i32 %a, i32 %b) { +entry: + %add = add i32 %a, %b + ret i32 %add +} + +define i32 @simple_mul(i32 %x, i32 %y) { +entry: + %mul = mul i32 %x, %y + ret i32 %mul +} + +define i32 @test_function(i32 %arg1, i32 %arg2) { +entry: + %local1 = alloca i32, align 4 + %local2 = alloca i32, align 4 + store i32 %arg1, ptr %local1, align 4 + store i32 %arg2, ptr %local2, align 4 + %load1 = load i32, ptr %local1, align 4 + %load2 = load i32, ptr %local2, align 4 + %result = add i32 %load1, %load2 + ret i32 %result +} + +; TRIPLETS: Add IntegerTy Variable Variable +; TRIPLETS-NEXT: Ret VoidTy Variable +; TRIPLETS-NEXT: Mul IntegerTy Variable Variable +; TRIPLETS-NEXT: Ret VoidTy Variable +; TRIPLETS-NEXT: Alloca PointerTy Constant +; TRIPLETS-NEXT: Alloca PointerTy Constant +; TRIPLETS-NEXT: Store VoidTy Variable Pointer +; TRIPLETS-NEXT: Store VoidTy Variable Pointer +; TRIPLETS-NEXT: Load IntegerTy Pointer +; TRIPLETS-NEXT: Load IntegerTy Pointer +; TRIPLETS-NEXT: Add IntegerTy Variable Variable +; TRIPLETS-NEXT: Ret VoidTy Variable diff --git a/llvm/tools/llvm-ir2vec/CMakeLists.txt b/llvm/tools/llvm-ir2vec/CMakeLists.txt new file mode 100644 index 0000000000000..a4cf9690e86b5 --- /dev/null +++ b/llvm/tools/llvm-ir2vec/CMakeLists.txt @@ -0,0 +1,10 @@ +set(LLVM_LINK_COMPONENTS + Analysis + Core + IRReader + Support + ) + +add_llvm_tool(llvm-ir2vec + llvm-ir2vec.cpp + ) diff --git a/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp new file mode 100644 index 0000000000000..35e1c995fa4cc --- /dev/null +++ b/llvm/tools/llvm-ir2vec/llvm-ir2vec.cpp @@ -0,0 +1,150 @@ +//===- llvm-ir2vec.cpp - IR2Vec Embedding Generation Tool -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file implements the IR2Vec embedding generation tool. +/// +/// Currently supports triplet generation for vocabulary training. +/// Future updates will support embedding generation using trained vocabulary. +/// +/// Usage: llvm-ir2vec input.bc -o triplets.txt +/// +/// TODO: Add embedding generation mode with vocabulary support +/// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/IR2Vec.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Errc.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; +using namespace ir2vec; + +#define DEBUG_TYPE "ir2vec" + +static cl::OptionCategory IR2VecToolCategory("IR2Vec Tool Options"); + +static cl::opt InputFilename(cl::Positional, + cl::desc(""), + cl::Required, + cl::cat(IR2VecToolCategory)); + +static cl::opt OutputFilename("o", cl::desc("Output filename"), + cl::value_desc("filename"), + cl::init("-"), + cl::cat(IR2VecToolCategory)); + +namespace { + +/// Helper class for collecting IR information and generating triplets +class IR2VecTool { +private: + Module &M; + +public: + explicit IR2VecTool(Module &M) : M(M) {} + + /// Generate triplets for the entire module + void generateTriplets(raw_ostream &OS) const { + for (const Function &F : M) + generateTriplets(F, OS); + } + + /// Generate triplets for a single function + void generateTriplets(const Function &F, raw_ostream &OS) const { + if (F.isDeclaration()) + return; + + std::string LocalOutput; + raw_string_ostream LocalOS(LocalOutput); + + for (const BasicBlock &BB : F) + traverseBasicBlock(BB, LocalOS); + + LocalOS.flush(); + OS << LocalOutput; + } + +private: + /// Process a single basic block for triplet generation + void traverseBasicBlock(const BasicBlock &BB, raw_string_ostream &OS) const { + // Consider only non-debug and non-pseudo instructions + for (const auto &I : BB.instructionsWithoutDebug()) { + StringRef OpcStr = Vocabulary::getVocabKeyForOpcode(I.getOpcode()); + StringRef TypeStr = + Vocabulary::getVocabKeyForTypeID(I.getType()->getTypeID()); + + OS << '\n' << OpcStr << ' ' << TypeStr << ' '; + + LLVM_DEBUG(I.print(dbgs()); dbgs() << "\n"); + LLVM_DEBUG(I.getType()->print(dbgs()); dbgs() << " Type\n"); + + for (const Use &U : I.operands()) + OS << Vocabulary::getVocabKeyForOperandKind( + Vocabulary::getOperandKind(U.get())) + << ' '; + } + } +}; + +Error processModule(Module &M, raw_ostream &OS) { + IR2VecTool Tool(M); + Tool.generateTriplets(OS); + + return Error::success(); +} + +} // anonymous namespace + +int main(int argc, char **argv) { + InitLLVM X(argc, argv); + cl::HideUnrelatedOptions(IR2VecToolCategory); + cl::ParseCommandLineOptions( + argc, argv, + "IR2Vec - Triplet Generation Tool\n" + "Generates triplets for vocabulary training from LLVM IR.\n" + "Future updates will support embedding generation.\n\n" + "Usage:\n" + " llvm-ir2vec input.bc -o triplets.txt\n"); + + // Parse the input LLVM IR file + SMDiagnostic Err; + LLVMContext Context; + std::unique_ptr M = parseIRFile(InputFilename, Err, Context); + if (!M) { + Err.print(argv[0], errs()); + return 1; + } + + std::error_code EC; + raw_fd_ostream OS(OutputFilename, EC); + if (EC) { + errs() << "Error opening output file: " << EC.message() << "\n"; + return 1; + } + + if (Error Err = processModule(*M, OS)) { + handleAllErrors(std::move(Err), [&](const ErrorInfoBase &EIB) { + errs() << "Error: " << EIB.message() << "\n"; + }); + return 1; + } + + return 0; +}