diff --git a/.gitignore b/.gitignore index 4bce7875f7b..66512eb0505 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,5 @@ *.a *.elf *.out +default.profraw +default.profexport diff --git a/Cargo.lock b/Cargo.lock index 5e5cbeda91f..0f4862e8fdb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1243,8 +1243,7 @@ dependencies = [ [[package]] name = "inkwell" version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f4fcb4a4fa0b8f7b4178e24e6317d6f8b95ab500d8e6e1bd4283b6860e369c1" +source = "git+https://github.com/corbanvilla/inkwell.git?branch=v2.0#583a6c20522cf8a743b9daf214f254424bf14eb2" dependencies = [ "either", "inkwell_internals", @@ -1257,8 +1256,7 @@ dependencies = [ [[package]] name = "inkwell_internals" version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b185e7d068d6820411502efa14d8fbf010750485399402156b72dd2a548ef8e9" +source = "git+https://github.com/corbanvilla/inkwell.git?branch=v2.0#583a6c20522cf8a743b9daf214f254424bf14eb2" dependencies = [ "proc-macro2 1.0.67", "quote 1.0.33", @@ -2241,6 +2239,24 @@ version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" +[[package]] +name = "rustc_llvm" +version = "0.0.0" +dependencies = [ + "cc", + "libc", +] + +[[package]] +name = "rustc_llvm_coverage" +version = "0.1.0" +dependencies = [ + "inkwell", + "libc", + "llvm-sys", + "rustc_llvm", +] + [[package]] name = "rustix" version = "0.37.25" @@ -2296,6 +2312,7 @@ dependencies = [ "plc_xml", "pretty_assertions", "regex", + "rustc_llvm_coverage", "serde", "serde_json", "serial_test", diff --git a/Cargo.toml b/Cargo.toml index ebaef5a49eb..e5e455fc305 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ plc_source = { path = "./compiler/plc_source" } plc_ast = { path = "./compiler/plc_ast" } plc_util = { path = "./compiler/plc_util" } plc_diagnostics = { path = "./compiler/plc_diagnostics" } +rustc_llvm_coverage = { path = "./compiler/rustc_llvm_coverage" } logos = "0.12.0" thiserror = "1.0" clap = { version = "3.0", features = ["derive"] } @@ -69,11 +70,13 @@ members = [ "compiler/plc_util", "compiler/plc_xml", "compiler/plc_derive", + "compiler/rustc_llvm", + "compiler/rustc_llvm_coverage", ] default-members = [".", "compiler/plc_driver", "compiler/plc_xml"] [workspace.dependencies] -inkwell = { version = "0.2", features = ["llvm14-0"] } +inkwell = { git = "https://github.com/corbanvilla/inkwell.git", branch = "v2.0", version = "0.2", features = ["llvm14-0"] } encoding_rs = "0.8" encoding_rs_io = "0.1" log = "0.4" diff --git a/compiler/plc_source/src/source_location.rs b/compiler/plc_source/src/source_location.rs index beafedd4d77..adbc1c5c8b2 100644 --- a/compiler/plc_source/src/source_location.rs +++ b/compiler/plc_source/src/source_location.rs @@ -296,6 +296,19 @@ impl SourceLocation { true } } + + // TOOD - there's probably a better way to do this + pub fn get_start_end(&self) -> (usize, usize, usize, usize) { + let span = self.get_span(); + if let CodeSpan::Range(range) = span { + let (start_line, start_col, end_line, end_col) = + // TODO - rusty devs will understand what needs to be done here w/ line + 1 + (range.start.line + 1, range.start.column + 1, range.end.line + 1, range.end.column + 1); + (start_line, start_col, end_line, end_col) + } else { + panic!("Error: expected CodeSpan::Range, found {:?}", span); + } + } } /** diff --git a/compiler/rustc_llvm/Cargo.toml b/compiler/rustc_llvm/Cargo.toml new file mode 100644 index 00000000000..34556df3c6d --- /dev/null +++ b/compiler/rustc_llvm/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "rustc_llvm" +version = "0.0.0" +edition = "2021" + +[features] +static-libstdcpp = [] +emscripten = [] + +[dependencies] +libc = "0.2.73" + +[build-dependencies] +cc = "1.0.69" diff --git a/compiler/rustc_llvm/README.md b/compiler/rustc_llvm/README.md new file mode 100644 index 00000000000..204eb75d94a --- /dev/null +++ b/compiler/rustc_llvm/README.md @@ -0,0 +1,5 @@ +# `rustc-llvm` + +This package serves to wrap some of the LLVM functions which are not natively exposed as part of the LLVM C-API, in a Rust-friendly way. + +This code is taken directly from the [Rust compiler source code (version 1.64.0)](https://github.com/rust-lang/rust/tree/a55dd71d5fb0ec5a6a3a9e8c27b2127ba491ce52), which is the last version of the Rust compiler to use LLVM 14 (which is currently the version used by `ruSTy`). The Rust compiler uses this code to interface with LLVM in order to add code coverage instrumentation to Rust binaries, among other features. diff --git a/compiler/rustc_llvm/build.rs b/compiler/rustc_llvm/build.rs new file mode 100644 index 00000000000..345e2c7cc0d --- /dev/null +++ b/compiler/rustc_llvm/build.rs @@ -0,0 +1,356 @@ +use std::env; +use std::ffi::{OsStr, OsString}; +use std::fmt::Display; +use std::path::{Path, PathBuf}; +use std::process::{Command, Stdio}; + +const OPTIONAL_COMPONENTS: &[&str] = &[ + "x86", + "arm", + "aarch64", + "amdgpu", + "avr", + "m68k", + "mips", + "powerpc", + "systemz", + "jsbackend", + "webassembly", + "msp430", + "sparc", + "nvptx", + "hexagon", + "riscv", + "bpf", +]; + +const REQUIRED_COMPONENTS: &[&str] = + &["ipo", "bitreader", "bitwriter", "linker", "asmparser", "lto", "coverage", "instrumentation"]; + +fn detect_llvm_link() -> (&'static str, &'static str) { + // Force the link mode we want, preferring static by default, but + // possibly overridden by `configure --enable-llvm-link-shared`. + if tracked_env_var_os("LLVM_LINK_SHARED").is_some() { + ("dylib", "--link-shared") + } else { + ("static", "--link-static") + } +} + +// Because Cargo adds the compiler's dylib path to our library search path, llvm-config may +// break: the dylib path for the compiler, as of this writing, contains a copy of the LLVM +// shared library, which means that when our freshly built llvm-config goes to load it's +// associated LLVM, it actually loads the compiler's LLVM. In particular when building the first +// compiler (i.e., in stage 0) that's a problem, as the compiler's LLVM is likely different from +// the one we want to use. As such, we restore the environment to what bootstrap saw. This isn't +// perfect -- we might actually want to see something from Cargo's added library paths -- but +// for now it works. +fn restore_library_path() { + // let key = tracked_env_var_os("REAL_LIBRARY_PATH_VAR").expect("REAL_LIBRARY_PATH_VAR"); + // if let Some(env) = tracked_env_var_os("REAL_LIBRARY_PATH") { + // env::set_var(&key, &env); + // } else { + // env::remove_var(&key); + // } +} + +/// Reads an environment variable and adds it to dependencies. +/// Supposed to be used for all variables except those set for build scripts by cargo +/// +fn tracked_env_var_os + Display>(key: K) -> Option { + println!("cargo:rerun-if-env-changed={}", key); + env::var_os(key) +} + +fn rerun_if_changed_anything_in_dir(dir: &Path) { + let mut stack = + dir.read_dir().unwrap().map(|e| e.unwrap()).filter(|e| &*e.file_name() != ".git").collect::>(); + while let Some(entry) = stack.pop() { + let path = entry.path(); + if entry.file_type().unwrap().is_dir() { + stack.extend(path.read_dir().unwrap().map(|e| e.unwrap())); + } else { + println!("cargo:rerun-if-changed={}", path.display()); + } + } +} + +#[track_caller] +fn output(cmd: &mut Command) -> String { + let output = match cmd.stderr(Stdio::inherit()).output() { + Ok(status) => status, + Err(e) => { + println!("\n\nfailed to execute command: {:?}\nerror: {}\n\n", cmd, e); + std::process::exit(1); + } + }; + if !output.status.success() { + panic!( + "command did not execute successfully: {:?}\n\ + expected success, got: {}", + cmd, output.status + ); + } + String::from_utf8(output.stdout).unwrap() +} + +fn main() { + for component in REQUIRED_COMPONENTS.iter().chain(OPTIONAL_COMPONENTS.iter()) { + println!("cargo:rustc-check-cfg=values(llvm_component,\"{}\")", component); + } + + if tracked_env_var_os("RUST_CHECK").is_some() { + // If we're just running `check`, there's no need for LLVM to be built. + return; + } + + restore_library_path(); + + let target = env::var("TARGET").expect("TARGET was not set"); + let llvm_config = + tracked_env_var_os("LLVM_CONFIG").map(|x| Some(PathBuf::from(x))).unwrap_or_else(|| { + if let Some(dir) = tracked_env_var_os("CARGO_TARGET_DIR").map(PathBuf::from) { + let to_test = + dir.parent().unwrap().parent().unwrap().join(&target).join("llvm/bin/llvm-config"); + if Command::new(&to_test).output().is_ok() { + return Some(to_test); + } + } + None + }); + + if let Some(llvm_config) = &llvm_config { + println!("cargo:rerun-if-changed={}", llvm_config.display()); + } + let llvm_config = llvm_config.unwrap_or_else(|| PathBuf::from("llvm-config")); + + // Test whether we're cross-compiling LLVM. This is a pretty rare case + // currently where we're producing an LLVM for a different platform than + // what this build script is currently running on. + // + // In that case, there's no guarantee that we can actually run the target, + // so the build system works around this by giving us the LLVM_CONFIG for + // the host platform. This only really works if the host LLVM and target + // LLVM are compiled the same way, but for us that's typically the case. + // + // We *want* detect this cross compiling situation by asking llvm-config + // what its host-target is. If that's not the TARGET, then we're cross + // compiling. Unfortunately `llvm-config` seems either be buggy, or we're + // misconfiguring it, because the `i686-pc-windows-gnu` build of LLVM will + // report itself with a `--host-target` of `x86_64-pc-windows-gnu`. This + // tricks us into thinking we're doing a cross build when we aren't, so + // havoc ensues. + // + // In any case, if we're cross compiling, this generally just means that we + // can't trust all the output of llvm-config because it might be targeted + // for the host rather than the target. As a result a bunch of blocks below + // are gated on `if !is_crossed` + let target = env::var("TARGET").expect("TARGET was not set"); + let host = env::var("HOST").expect("HOST was not set"); + let is_crossed = target != host; + + let components = output(Command::new(&llvm_config).arg("--components")); + let mut components = components.split_whitespace().collect::>(); + components.retain(|c| OPTIONAL_COMPONENTS.contains(c) || REQUIRED_COMPONENTS.contains(c)); + + for component in REQUIRED_COMPONENTS { + if !components.contains(component) { + panic!("require llvm component {} but wasn't found", component); + } + } + + for component in components.iter() { + println!("cargo:rustc-cfg=llvm_component=\"{}\"", component); + } + + // Link in our own LLVM shims, compiled with the same flags as LLVM + let mut cmd = Command::new(&llvm_config); + cmd.arg("--cxxflags"); + let cxxflags = output(&mut cmd); + let mut cfg = cc::Build::new(); + cfg.warnings(false); + for flag in cxxflags.split_whitespace() { + // Ignore flags like `-m64` when we're doing a cross build + if is_crossed && flag.starts_with("-m") { + continue; + } + + if flag.starts_with("-flto") { + continue; + } + + // -Wdate-time is not supported by the netbsd cross compiler + if is_crossed && target.contains("netbsd") && flag.contains("date-time") { + continue; + } + + // Include path contains host directory, replace it with target + if is_crossed && flag.starts_with("-I") { + cfg.flag(&flag.replace(&host, &target)); + continue; + } + + cfg.flag(flag); + } + + for component in &components { + let mut flag = String::from("LLVM_COMPONENT_"); + flag.push_str(&component.to_uppercase()); + cfg.define(&flag, None); + } + + if tracked_env_var_os("LLVM_RUSTLLVM").is_some() { + cfg.define("LLVM_RUSTLLVM", None); + } + + if tracked_env_var_os("LLVM_NDEBUG").is_some() { + cfg.define("NDEBUG", None); + cfg.debug(false); + } + + rerun_if_changed_anything_in_dir(Path::new("llvm-wrapper")); + cfg.file("llvm-wrapper/CoverageMappingWrapper.cpp") + .cpp(true) + .cpp_link_stdlib(None) // we handle this below + .compile("llvm-wrapper"); + + let (llvm_kind, llvm_link_arg) = detect_llvm_link(); + + // Link in all LLVM libraries, if we're using the "wrong" llvm-config then + // we don't pick up system libs because unfortunately they're for the host + // of llvm-config, not the target that we're attempting to link. + let mut cmd = Command::new(&llvm_config); + cmd.arg(llvm_link_arg).arg("--libs"); + + if !is_crossed { + cmd.arg("--system-libs"); + } else if target.contains("windows-gnu") { + println!("cargo:rustc-link-lib=shell32"); + println!("cargo:rustc-link-lib=uuid"); + } else if target.contains("netbsd") || target.contains("haiku") || target.contains("darwin") { + println!("cargo:rustc-link-lib=z"); + } + cmd.args(&components); + + for lib in output(&mut cmd).split_whitespace() { + let name = if let Some(stripped) = lib.strip_prefix("-l") { + stripped + } else if let Some(stripped) = lib.strip_prefix('-') { + stripped + } else if Path::new(lib).exists() { + // On MSVC llvm-config will print the full name to libraries, but + // we're only interested in the name part + let name = Path::new(lib).file_name().unwrap().to_str().unwrap(); + name.trim_end_matches(".lib") + } else if lib.ends_with(".lib") { + // Some MSVC libraries just come up with `.lib` tacked on, so chop + // that off + lib.trim_end_matches(".lib") + } else { + continue; + }; + + // Don't need or want this library, but LLVM's CMake build system + // doesn't provide a way to disable it, so filter it here even though we + // may or may not have built it. We don't reference anything from this + // library and it otherwise may just pull in extra dependencies on + // libedit which we don't want + if name == "LLVMLineEditor" { + continue; + } + + let kind = if name.starts_with("LLVM") { llvm_kind } else { "dylib" }; + println!("cargo:rustc-link-lib={}={}", kind, name); + } + + // LLVM ldflags + // + // If we're a cross-compile of LLVM then unfortunately we can't trust these + // ldflags (largely where all the LLVM libs are located). Currently just + // hack around this by replacing the host triple with the target and pray + // that those -L directories are the same! + let mut cmd = Command::new(&llvm_config); + cmd.arg(llvm_link_arg).arg("--ldflags"); + for lib in output(&mut cmd).split_whitespace() { + if is_crossed { + if let Some(stripped) = lib.strip_prefix("-LIBPATH:") { + println!("cargo:rustc-link-search=native={}", stripped.replace(&host, &target)); + } else if let Some(stripped) = lib.strip_prefix("-L") { + println!("cargo:rustc-link-search=native={}", stripped.replace(&host, &target)); + } + } else if let Some(stripped) = lib.strip_prefix("-LIBPATH:") { + println!("cargo:rustc-link-search=native={}", stripped); + } else if let Some(stripped) = lib.strip_prefix("-l") { + println!("cargo:rustc-link-lib={}", stripped); + } else if let Some(stripped) = lib.strip_prefix("-L") { + println!("cargo:rustc-link-search=native={}", stripped); + } + } + + // Some LLVM linker flags (-L and -l) may be needed even when linking + // rustc_llvm, for example when using static libc++, we may need to + // manually specify the library search path and -ldl -lpthread as link + // dependencies. + let llvm_linker_flags = tracked_env_var_os("LLVM_LINKER_FLAGS"); + if let Some(s) = llvm_linker_flags { + for lib in s.into_string().unwrap().split_whitespace() { + if let Some(stripped) = lib.strip_prefix("-l") { + println!("cargo:rustc-link-lib={}", stripped); + } else if let Some(stripped) = lib.strip_prefix("-L") { + println!("cargo:rustc-link-search=native={}", stripped); + } + } + } + + let llvm_static_stdcpp = tracked_env_var_os("LLVM_STATIC_STDCPP"); + let llvm_use_libcxx = tracked_env_var_os("LLVM_USE_LIBCXX"); + + let stdcppname = if target.contains("openbsd") { + if target.contains("sparc64") { + "estdc++" + } else { + "c++" + } + } else if target.contains("darwin") || target.contains("freebsd") || target.contains("windows-gnullvm") { + "c++" + } else if target.contains("netbsd") && llvm_static_stdcpp.is_some() { + // NetBSD uses a separate library when relocation is required + "stdc++_pic" + } else if llvm_use_libcxx.is_some() { + "c++" + } else { + "stdc++" + }; + + // RISC-V GCC erroneously requires libatomic for sub-word + // atomic operations. FreeBSD uses Clang as its system + // compiler and provides no libatomic in its base system so + // does not want this. + if !target.contains("freebsd") && target.starts_with("riscv") { + println!("cargo:rustc-link-lib=atomic"); + } + + // C++ runtime library + if !target.contains("msvc") { + if let Some(s) = llvm_static_stdcpp { + assert!(!cxxflags.contains("stdlib=libc++")); + let path = PathBuf::from(s); + println!("cargo:rustc-link-search=native={}", path.parent().unwrap().display()); + if target.contains("windows") { + println!("cargo:rustc-link-lib=static:-bundle={}", stdcppname); + } else { + println!("cargo:rustc-link-lib=static={}", stdcppname); + } + } else if cxxflags.contains("stdlib=libc++") { + println!("cargo:rustc-link-lib=c++"); + } else { + println!("cargo:rustc-link-lib={}", stdcppname); + } + } + + // Libstdc++ depends on pthread which Rust doesn't link on MinGW + // since nothing else requires it. + if target.ends_with("windows-gnu") { + println!("cargo:rustc-link-lib=static:-bundle=pthread"); + } +} diff --git a/compiler/rustc_llvm/llvm-wrapper/.editorconfig b/compiler/rustc_llvm/llvm-wrapper/.editorconfig new file mode 100644 index 00000000000..865cd45f708 --- /dev/null +++ b/compiler/rustc_llvm/llvm-wrapper/.editorconfig @@ -0,0 +1,6 @@ +[*.{h,cpp}] +end_of_line = lf +insert_final_newline = true +charset = utf-8 +indent_style = space +indent_size = 2 diff --git a/compiler/rustc_llvm/llvm-wrapper/CoverageMappingWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/CoverageMappingWrapper.cpp new file mode 100644 index 00000000000..8ff8b1d079c --- /dev/null +++ b/compiler/rustc_llvm/llvm-wrapper/CoverageMappingWrapper.cpp @@ -0,0 +1,122 @@ +#include "LLVMWrapper.h" +#include "llvm/ProfileData/Coverage/CoverageMapping.h" +#include "llvm/ProfileData/Coverage/CoverageMappingWriter.h" +#include "llvm/ProfileData/InstrProf.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" + +#include + +using namespace llvm; + +struct LLVMRustCounterMappingRegion { + coverage::Counter Count; + coverage::Counter FalseCount; + uint32_t FileID; + uint32_t ExpandedFileID; + uint32_t LineStart; + uint32_t ColumnStart; + uint32_t LineEnd; + uint32_t ColumnEnd; + coverage::CounterMappingRegion::RegionKind Kind; +}; + +extern "C" void LLVMRustCoverageWriteFilenamesSectionToBuffer( + const char* const Filenames[], + size_t FilenamesLen, + RustStringRef BufferOut) { +#if LLVM_VERSION_GE(13,0) + SmallVector FilenameRefs; + for (size_t i = 0; i < FilenamesLen; i++) { + FilenameRefs.push_back(std::string(Filenames[i])); + } +#else + SmallVector FilenameRefs; + for (size_t i = 0; i < FilenamesLen; i++) { + FilenameRefs.push_back(StringRef(Filenames[i])); + } +#endif + auto FilenamesWriter = coverage::CoverageFilenamesSectionWriter( + makeArrayRef(FilenameRefs)); + RawRustStringOstream OS(BufferOut); + FilenamesWriter.write(OS); +} + +extern "C" void LLVMRustCoverageWriteMappingToBuffer( + const unsigned *VirtualFileMappingIDs, + unsigned NumVirtualFileMappingIDs, + const coverage::CounterExpression *Expressions, + unsigned NumExpressions, + LLVMRustCounterMappingRegion *RustMappingRegions, + unsigned NumMappingRegions, + RustStringRef BufferOut) { + // Convert from FFI representation to LLVM representation. + SmallVector MappingRegions; + MappingRegions.reserve(NumMappingRegions); + for (const auto &Region : makeArrayRef(RustMappingRegions, NumMappingRegions)) { + MappingRegions.emplace_back( + Region.Count, Region.FalseCount, Region.FileID, Region.ExpandedFileID, + Region.LineStart, Region.ColumnStart, Region.LineEnd, Region.ColumnEnd, + Region.Kind); + } + auto CoverageMappingWriter = coverage::CoverageMappingWriter( + makeArrayRef(VirtualFileMappingIDs, NumVirtualFileMappingIDs), + makeArrayRef(Expressions, NumExpressions), + MappingRegions); + RawRustStringOstream OS(BufferOut); + CoverageMappingWriter.write(OS); +} + +extern "C" LLVMValueRef LLVMRustCoverageCreatePGOFuncNameVar(LLVMValueRef F, const char *FuncName) { + StringRef FuncNameRef(FuncName); + return wrap(createPGOFuncNameVar(*cast(unwrap(F)), FuncNameRef)); +} + +extern "C" uint64_t LLVMRustCoverageHashCString(const char *StrVal) { + StringRef StrRef(StrVal); + return IndexedInstrProf::ComputeHash(StrRef); +} + +extern "C" uint64_t LLVMRustCoverageHashByteArray( + const char *Bytes, + unsigned NumBytes) { + StringRef StrRef(Bytes, NumBytes); + return IndexedInstrProf::ComputeHash(StrRef); +} + +static void WriteSectionNameToString(LLVMModuleRef M, + InstrProfSectKind SK, + RustStringRef Str) { + Triple TargetTriple(unwrap(M)->getTargetTriple()); + auto name = getInstrProfSectionName(SK, TargetTriple.getObjectFormat()); + RawRustStringOstream OS(Str); + OS << name; +} + +extern "C" void LLVMRustCoverageWriteMapSectionNameToString(LLVMModuleRef M, + RustStringRef Str) { + WriteSectionNameToString(M, IPSK_covmap, Str); +} + +extern "C" void LLVMRustCoverageWriteFuncSectionNameToString(LLVMModuleRef M, + RustStringRef Str) { + WriteSectionNameToString(M, IPSK_covfun, Str); +} + +extern "C" void LLVMRustCoverageWriteMappingVarNameToString(RustStringRef Str) { + auto name = getCoverageMappingVarName(); + RawRustStringOstream OS(Str); + OS << name; +} + +extern "C" uint32_t LLVMRustCoverageMappingVersion() { +#if LLVM_VERSION_GE(13, 0) + return coverage::CovMapVersion::Version6; +#else + return coverage::CovMapVersion::Version5; +#endif +} + +extern "C" void LLVMRustAppendToUsed(LLVMModuleRef M, GlobalValue *G) { + appendToUsed(*unwrap(M), makeArrayRef(G)); +} diff --git a/compiler/rustc_llvm/llvm-wrapper/LLVMWrapper.h b/compiler/rustc_llvm/llvm-wrapper/LLVMWrapper.h new file mode 100644 index 00000000000..015c1c52bef --- /dev/null +++ b/compiler/rustc_llvm/llvm-wrapper/LLVMWrapper.h @@ -0,0 +1,121 @@ +#include "llvm-c/BitReader.h" +#include "llvm-c/Core.h" +#include "llvm-c/Object.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Triple.h" +#include "llvm/Analysis/Lint.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InlineAsm.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/DynamicLibrary.h" +#include "llvm/Support/FormattedStream.h" +#include "llvm/Support/Host.h" +#include "llvm/Support/Memory.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/TargetSelect.h" +#include "llvm/Support/Timer.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Target/TargetOptions.h" +#include "llvm/Transforms/IPO.h" +#include "llvm/Transforms/Instrumentation.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Vectorize.h" + +#define LLVM_VERSION_GE(major, minor) \ + (LLVM_VERSION_MAJOR > (major) || \ + LLVM_VERSION_MAJOR == (major) && LLVM_VERSION_MINOR >= (minor)) + +#define LLVM_VERSION_LT(major, minor) (!LLVM_VERSION_GE((major), (minor))) + +#include "llvm/IR/LegacyPassManager.h" + +#include "llvm/Bitcode/BitcodeReader.h" +#include "llvm/Bitcode/BitcodeWriter.h" + +#include "llvm/IR/DIBuilder.h" +#include "llvm/IR/DebugInfo.h" +#include "llvm/IR/IRPrintingPasses.h" +#include "llvm/Linker/Linker.h" + +extern "C" void LLVMRustSetLastError(const char *); + +enum class LLVMRustResult { Success, Failure }; + +enum LLVMRustAttribute { + AlwaysInline = 0, + ByVal = 1, + Cold = 2, + InlineHint = 3, + MinSize = 4, + Naked = 5, + NoAlias = 6, + NoCapture = 7, + NoInline = 8, + NonNull = 9, + NoRedZone = 10, + NoReturn = 11, + NoUnwind = 12, + OptimizeForSize = 13, + ReadOnly = 14, + SExt = 15, + StructRet = 16, + UWTable = 17, + ZExt = 18, + InReg = 19, + SanitizeThread = 20, + SanitizeAddress = 21, + SanitizeMemory = 22, + NonLazyBind = 23, + OptimizeNone = 24, + ReturnsTwice = 25, + ReadNone = 26, + InaccessibleMemOnly = 27, + SanitizeHWAddress = 28, + WillReturn = 29, + StackProtectReq = 30, + StackProtectStrong = 31, + StackProtect = 32, + NoUndef = 33, + SanitizeMemTag = 34, + NoCfCheck = 35, + ShadowCallStack = 36, + AllocSize = 37, +#if LLVM_VERSION_GE(15, 0) + AllocatedPointer = 38, + AllocAlign = 39, +#endif +}; + +typedef struct OpaqueRustString *RustStringRef; +typedef struct LLVMOpaqueTwine *LLVMTwineRef; +typedef struct LLVMOpaqueSMDiagnostic *LLVMSMDiagnosticRef; + +extern "C" void LLVMRustStringWriteImpl(RustStringRef Str, const char *Ptr, + size_t Size); + +class RawRustStringOstream : public llvm::raw_ostream { + RustStringRef Str; + uint64_t Pos; + + void write_impl(const char *Ptr, size_t Size) override { + LLVMRustStringWriteImpl(Str, Ptr, Size); + Pos += Size; + } + + uint64_t current_pos() const override { return Pos; } + +public: + explicit RawRustStringOstream(RustStringRef Str) : Str(Str), Pos(0) {} + + ~RawRustStringOstream() { + // LLVM requires this. + flush(); + } +}; diff --git a/compiler/rustc_llvm/llvm-wrapper/README b/compiler/rustc_llvm/llvm-wrapper/README new file mode 100644 index 00000000000..e1c6dd07d2b --- /dev/null +++ b/compiler/rustc_llvm/llvm-wrapper/README @@ -0,0 +1,16 @@ +This directory currently contains some LLVM support code. This will generally +be sent upstream to LLVM in time; for now it lives here. + +NOTE: the LLVM C++ ABI is subject to between-version breakage and must *never* +be exposed to Rust. To allow for easy auditing of that, all Rust-exposed types +must be typedef-ed as "LLVMXyz", or "LLVMRustXyz" if they were defined here. + +Functions that return a failure status and leave the error in +the LLVM last error should return an LLVMRustResult rather than an +int or anything to avoid confusion. + +When translating enums, add a single `Other` variant as the first +one to allow for new variants to be added. It should abort when used +as an input. + +All other types must not be typedef-ed as such. diff --git a/compiler/rustc_llvm/src/lib.rs b/compiler/rustc_llvm/src/lib.rs new file mode 100644 index 00000000000..8b137891791 --- /dev/null +++ b/compiler/rustc_llvm/src/lib.rs @@ -0,0 +1 @@ + diff --git a/compiler/rustc_llvm_coverage/Cargo.toml b/compiler/rustc_llvm_coverage/Cargo.toml new file mode 100644 index 00000000000..0def989f990 --- /dev/null +++ b/compiler/rustc_llvm_coverage/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "rustc_llvm_coverage" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +libc = "0.2.73" +rustc_llvm = { path = "../rustc_llvm" } +inkwell = { git = "https://github.com/corbanvilla/inkwell.git", branch = "v2.0", version = "0.2", features = ["llvm14-0"] } +llvm-sys = { package = "llvm-sys", version = "140.0.2" } diff --git a/compiler/rustc_llvm_coverage/src/ffi.rs b/compiler/rustc_llvm_coverage/src/ffi.rs new file mode 100644 index 00000000000..ccac1390f40 --- /dev/null +++ b/compiler/rustc_llvm_coverage/src/ffi.rs @@ -0,0 +1,60 @@ +#![allow(dead_code, unused_variables)] + +// This FFI links to the static library built by `rustc_llvm` +// +// Function interface definitions are taken from [here](https://github.com/rust-lang/rust/blob/84c898d65adf2f39a5a98507f1fe0ce10a2b8dbc/compiler/rustc_codegen_llvm/src/llvm/ffi.rs#L1864). +// +use libc::{c_char, c_uint, size_t}; +use std::slice; + +use super::types::*; +use inkwell::values::PointerValue; +use llvm_sys::prelude::{LLVMModuleRef, LLVMValueRef}; + +/// Appending to a Rust string -- used by RawRustStringOstream. +#[no_mangle] +pub unsafe extern "C" fn LLVMRustStringWriteImpl(sr: &RustString, ptr: *const c_char, size: size_t) { + let slice = slice::from_raw_parts(ptr as *const u8, size as usize); + + sr.bytes.borrow_mut().extend_from_slice(slice); +} + +#[link(name = "llvm-wrapper", kind = "static")] +extern "C" { + #[allow(improper_ctypes)] + pub fn LLVMRustCoverageWriteFilenamesSectionToBuffer( + Filenames: *const *const c_char, + FilenamesLen: size_t, + BufferOut: &RustString, + ); + + #[allow(improper_ctypes)] + pub fn LLVMRustCoverageWriteMappingToBuffer( + VirtualFileMappingIDs: *const c_uint, + NumVirtualFileMappingIDs: c_uint, + Expressions: *const CounterExpression, + NumExpressions: c_uint, + MappingRegions: *const CounterMappingRegion, + NumMappingRegions: c_uint, + BufferOut: &RustString, + ); + + pub fn LLVMRustCoverageCreatePGOFuncNameVar(F: LLVMValueRef, FuncName: *const c_char) -> LLVMValueRef; + pub fn LLVMRustCoverageHashCString(StrVal: *const c_char) -> u64; + pub fn LLVMRustCoverageHashByteArray(Bytes: *const c_char, NumBytes: size_t) -> u64; + + #[allow(improper_ctypes)] + pub fn LLVMRustCoverageWriteMapSectionNameToString(M: LLVMModuleRef, Str: &RustString); + + #[allow(improper_ctypes)] + pub fn LLVMRustCoverageWriteFuncSectionNameToString(M: LLVMModuleRef, Str: &RustString); + + #[allow(improper_ctypes)] + pub fn LLVMRustCoverageWriteMappingVarNameToString(Str: &RustString); + + pub fn LLVMRustCoverageMappingVersion() -> u32; + + #[allow(improper_ctypes)] + pub fn LLVMRustAppendToUsed(M: LLVMModuleRef, V: PointerValue); + +} diff --git a/compiler/rustc_llvm_coverage/src/interfaces.rs b/compiler/rustc_llvm_coverage/src/interfaces.rs new file mode 100644 index 00000000000..dd886938214 --- /dev/null +++ b/compiler/rustc_llvm_coverage/src/interfaces.rs @@ -0,0 +1,162 @@ +/* + * Many of the functions in this file have been adapted from the + * `rustc` implementation of LLVM code coverage. + * + * https://github.com/rust-lang/rust/blob/84c898d65adf2f39a5a98507f1fe0ce10a2b8dbc/compiler/rustc_codegen_llvm/src/coverageinfo/mod.rs#L220-L221 + * + * TODO - Consider updating functions to reflect configurations in latest Rust (not 1.64) + * https://github.com/rust-lang/rust/blob/master/compiler/rustc_codegen_llvm/src/coverageinfo/mod.rs +*/ + +const VAR_ALIGN_BYTES: u32 = 8; + +use std::string::FromUtf8Error; + +use super::*; +use crate::types::*; + +use inkwell::{ + module::Linkage, + values::{AsValueRef, FunctionValue, GlobalValue, StructValue}, + GlobalVisibility, +}; + +use libc::c_uint; +use std::ffi::CString; + +use inkwell::module::Module; + +/* == TODO - Refactor these helpers out */ +pub fn build_string(sr: &RustString) -> Result { + String::from_utf8(sr.bytes.borrow().clone()) +} +/* == END TODO */ + +/// Calls llvm::createPGOFuncNameVar() with the given function instance's +/// mangled function name. The LLVM API returns an llvm::GlobalVariable +/// containing the function name, with the specific variable name and linkage +/// required by LLVM InstrProf source-based coverage instrumentation. Use +/// `bx.get_pgo_func_name_var()` to ensure the variable is only created once per +/// `Instance`. +pub fn create_pgo_func_name_var<'ctx>(func: &FunctionValue<'ctx>) -> GlobalValue<'ctx> { + let pgo_function_ref = + unsafe { ffi::LLVMRustCoverageCreatePGOFuncNameVar(func.as_value_ref(), func.get_name().as_ptr()) }; + assert!(!pgo_function_ref.is_null()); + unsafe { GlobalValue::new(pgo_function_ref) } +} + +pub fn write_filenames_section_to_buffer<'a>( + filenames: impl IntoIterator, + buffer: &RustString, +) { + let c_str_vec = filenames.into_iter().map(|cstring| cstring.as_ptr()).collect::>(); + unsafe { + ffi::LLVMRustCoverageWriteFilenamesSectionToBuffer(c_str_vec.as_ptr(), c_str_vec.len(), buffer); + } +} +//create params , call fucntion in codegen, print the buffer +pub fn write_mapping_to_buffer( + virtual_file_mapping: Vec, + expressions: Vec, + mapping_regions: Vec, + buffer: &mut RustString, +) { + unsafe { + ffi::LLVMRustCoverageWriteMappingToBuffer( + virtual_file_mapping.as_ptr(), + virtual_file_mapping.len() as c_uint, + expressions.as_ptr(), + expressions.len() as c_uint, + mapping_regions.as_ptr(), + mapping_regions.len() as c_uint, + buffer, + ); + } +} + +pub fn hash_str(strval: &str) -> u64 { + let strval = CString::new(strval).expect("null error converting hashable str to C string"); + unsafe { ffi::LLVMRustCoverageHashCString(strval.as_ptr()) } +} + +pub fn hash_bytes(bytes: Vec) -> u64 { + unsafe { ffi::LLVMRustCoverageHashByteArray(bytes.as_ptr().cast(), bytes.len()) } +} + +pub fn get_mapping_version() -> u32 { + unsafe { ffi::LLVMRustCoverageMappingVersion() } +} + +pub fn save_cov_data_to_mod<'ctx>(module: &Module<'ctx>, cov_data_val: StructValue<'ctx>) { + let covmap_var_name = { + let mut s = RustString::new(); + unsafe { + ffi::LLVMRustCoverageWriteMappingVarNameToString(&mut s); + } + build_string(&mut s).expect("Rust Coverage Mapping var name failed UTF-8 conversion") + }; + + let covmap_section_name = { + let mut s = RustString::new(); + unsafe { + ffi::LLVMRustCoverageWriteMapSectionNameToString(module.as_mut_ptr(), &mut s); + } + build_string(&mut s).expect("Rust Coverage Mapping section name failed UTF-8 conversion") + }; + + let llglobal = module.add_global(cov_data_val.get_type(), None, covmap_var_name.as_str()); + llglobal.set_initializer(&cov_data_val); + llglobal.set_constant(true); + llglobal.set_linkage(Linkage::Private); + llglobal.set_section(Some(&covmap_section_name)); + llglobal.set_alignment(VAR_ALIGN_BYTES); + + // Mark as used to prevent removal by LLVM optimizations + unsafe { + ffi::LLVMRustAppendToUsed(module.as_mut_ptr(), llglobal.as_pointer_value()); + } +} + +pub fn save_func_record_to_mod<'ctx>( + module: &Module<'ctx>, + func_name_hash: u64, + func_record_val: StructValue<'ctx>, + is_used: bool, +) { + // Assign a name to the function record. This is used to merge duplicates. + // + // In LLVM, a "translation unit" (effectively, a `Crate` in Rust) can describe functions that + // are included-but-not-used. If (or when) Rust generates functions that are + // included-but-not-used, note that a dummy description for a function included-but-not-used + // in a Crate can be replaced by full description provided by a different Crate. The two kinds + // of descriptions play distinct roles in LLVM IR; therefore, assign them different names (by + // appending "u" to the end of the function record var name, to prevent `linkonce_odr` merging. + // TODO - investigate removing this (-Corban) + let func_record_var_name = format!("__covrec_{:X}{}", func_name_hash, if is_used { "u" } else { "" }); + + let func_record_section_name = { + let mut s = RustString::new(); + unsafe { + ffi::LLVMRustCoverageWriteFuncSectionNameToString(module.as_mut_ptr(), &mut s); + } + build_string(&mut s).expect("Rust Coverage function record section name failed UTF-8 conversion") + }; + + // Create types + let llglobal = module.add_global(func_record_val.get_type(), None, func_record_var_name.as_str()); + let comdat = module.get_or_insert_comdat(&func_record_var_name); + + // Assign + llglobal.set_initializer(&func_record_val); + llglobal.set_constant(true); + llglobal.set_linkage(Linkage::LinkOnceODR); + llglobal.set_visibility(GlobalVisibility::Hidden); + llglobal.set_section(Some(&func_record_section_name)); + llglobal.set_alignment(VAR_ALIGN_BYTES); + llglobal.set_comdat(comdat); + + // Mark as used to prevent removal by LLVM optimizations + unsafe { + ffi::LLVMRustAppendToUsed(module.as_mut_ptr(), llglobal.as_pointer_value()); + } +} diff --git a/compiler/rustc_llvm_coverage/src/lib.rs b/compiler/rustc_llvm_coverage/src/lib.rs new file mode 100644 index 00000000000..e9325f6bd86 --- /dev/null +++ b/compiler/rustc_llvm_coverage/src/lib.rs @@ -0,0 +1,304 @@ +//! This library provides a Rust interface to LLVM's coverage mapping format. +//! +//! This module exists to provide intuitive and useful abstractions for +//! interacting with LLVM's coverage mapping functions. If you want to +//! interact directly with LLVM, use the [`interfaces`] or [`ffi`] modules. +//! +//! + +pub mod ffi; +pub mod interfaces; +pub mod types; +use interfaces::create_pgo_func_name_var; +use interfaces::*; +use types::*; + +use inkwell::builder::Builder; +use inkwell::context::Context; +use inkwell::intrinsics::Intrinsic; +use inkwell::module::Module; +use inkwell::passes::PassBuilderOptions; +use inkwell::targets::{CodeModel, InitializationConfig, RelocMode, Target, TargetTriple}; +use inkwell::values::FunctionValue; +use inkwell::values::PointerValue; +use inkwell::OptimizationLevel; +use std::ffi::CString; + +/// This represents a coverage mapping header that has been written to a module. +/// It is returned for debugging purposes and use with write_function_record. +pub struct CoverageMappingHeader { + pub mapping_version: u32, + pub filenames: Vec, + pub filenames_hash: u64, + pub encoded_filename_buffer: RustString, +} + +impl CoverageMappingHeader { + pub fn new(filenames: Vec) -> Self { + // Get mapping version from LLVM + let mapping_version = get_mapping_version(); // versions are zero-indexed + assert_eq!(mapping_version, 5, "Only mapping version 6 is supported"); + + // Convert filenames to CStrings + let filenames_cstr = + filenames.clone().into_iter().map(|f| CString::new(f).unwrap()).collect::>(); + let mut encoded_filename_buffer = RustString::new(); + write_filenames_section_to_buffer(&filenames_cstr, &mut encoded_filename_buffer); + + // Calc file hash + let filenames_hash = hash_bytes(encoded_filename_buffer.bytes.borrow().to_vec()); + + CoverageMappingHeader { mapping_version, filenames, filenames_hash, encoded_filename_buffer } + } + + /// filenames: In Coverage Mapping Version > 6, first filename must be the compilation directory + pub fn write_coverage_mapping_header<'ctx>(&self, module: &Module<'ctx>) { + // Get context + let context = module.get_context(); + + // Create mapping header types + let i32_type = context.i32_type(); + let i32_zero = i32_type.const_int(0, false); + let i32_cov_mapping_version = i32_type.const_int(self.mapping_version.into(), false); + let i32_filenames_len = i32_type.const_int(self.encoded_filename_buffer.len() as u64, false); + + // See LLVM Code Coverage Specification for details on this data structure + let cov_mapping_header = context.const_struct( + &[ + // Value 1 : Always zero + i32_zero.into(), + // Value 2 : Len(encoded_filenames) + i32_filenames_len.into(), + // Value 3 : Always zero + i32_zero.into(), + // Value 4 : Mapping version + i32_cov_mapping_version.into(), + ], + // https://github.com/rust-lang/rust/blob/e6707df0de337976dce7577e68fc57adcd5e4842/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs#L301 + false, + ); + + // Create filename value types + let i8_type = context.i8_type(); + let i8_filename_array = i8_type.const_array( + &self + .encoded_filename_buffer + .bytes + .borrow() + .iter() + .map(|byte| i8_type.const_int(*byte as u64, false)) + .collect::>(), + ); + + // Create structure + let coverage_struct = + context.const_struct(&[cov_mapping_header.into(), i8_filename_array.into()], false); + + // Write to module + save_cov_data_to_mod(module, coverage_struct); + } +} + +pub struct FunctionRecord { + pub name: String, + pub name_md5_hash: u64, + pub structural_hash: u64, + pub virtual_file_mapping: Vec, + pub expressions: Vec, + pub mapping_regions: Vec, + pub mapping_buffer: RustString, + + // A.k.a. hash of all filenames in module + pub translation_unit_hash: u64, + pub is_used: bool, +} + +impl FunctionRecord { + /// TODO - Update to use a filename table, like + /// https://github.com/rust-lang/rust/blob/e6707df0de337976dce7577e68fc57adcd5e4842/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs#L155-L194 + pub fn new( + name: String, + structural_hash: u64, + // TODO - better names for these + function_filenames: Vec, + expressions: Vec, + mapping_regions: Vec, + is_used: bool, + + written_mapping_header: &CoverageMappingHeader, + ) -> Self { + let name_md5_hash = hash_str(&name); + + // Get indexes of function filenames in module file list + // TODO - hoist this into rusty + let mut virtual_file_mapping = Vec::new(); + for filename in function_filenames { + let filename_idx = written_mapping_header + .filenames + .iter() + .position(|f| f == &filename) + .expect("Unable to find function filename in module files"); + virtual_file_mapping.push(filename_idx.try_into().unwrap()); + } + + // Write mapping to buffer + let mut mapping_buffer = RustString::new(); + write_mapping_to_buffer( + virtual_file_mapping.clone(), + expressions.clone(), + mapping_regions.clone(), + &mut mapping_buffer, + ); + + FunctionRecord { + name, + name_md5_hash, + structural_hash, + virtual_file_mapping, + expressions, + is_used, + mapping_regions, + mapping_buffer, + translation_unit_hash: written_mapping_header.filenames_hash, + } + } + + pub fn write_to_module<'ctx>(&self, module: &Module<'ctx>) { + // Get context + let context = module.get_context(); + + // Create types + let i64_type = context.i64_type(); + let i32_type = context.i32_type(); + let i8_type = context.i8_type(); + + // Create values + let i64_name_md5_hash = i64_type.const_int(self.name_md5_hash, false); + let i32_mapping_len = i32_type.const_int(self.mapping_buffer.len() as u64, false); + let i64_structural_hash = i64_type.const_int(self.structural_hash, false); + let i64_translation_unit_hash = i64_type.const_int(self.translation_unit_hash, false); + + // Build mapping array + let i8_mapping_array = i8_type.const_array( + &self + .mapping_buffer + .bytes + .borrow() + .iter() + .map(|byte| i8_type.const_int(*byte as u64, false)) + .collect::>(), + ); + + // Create structure + let function_record_struct = context.const_struct( + &[ + i64_name_md5_hash.into(), + i32_mapping_len.into(), + i64_structural_hash.into(), + i64_translation_unit_hash.into(), + i8_mapping_array.into(), + ], + // https://github.com/rust-lang/rust/blob/e6707df0de337976dce7577e68fc57adcd5e4842/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs#L311 + true, + ); + + save_func_record_to_mod(&module, self.name_md5_hash, function_record_struct, self.is_used); + } +} + +/// This pass will not operate unless the module already has intrinsic calls. +/// See [here](https://github.com/llvm/llvm-project/blob/f28c006a5895fc0e329fe15fead81e37457cb1d1/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp#L539-L549) for why. +pub fn run_instrumentation_lowering_pass<'ctx>(module: &Module<'ctx>) { + // Setup + let initialization_config = &InitializationConfig::default(); + inkwell::targets::Target::initialize_all(initialization_config); + + // Architecture Specifics + // Module.set_triple() is required because the pass needs to know it's compiling + // to ELF [here](https://github.com/llvm/llvm-project/blob/cfa30fa4852275eed0c59b81b5d8088d3e55f778/llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp#L1191-L1199). + // TODO - pass this as a param + let triple = TargetTriple::create("x86_64-pc-linux-gnu"); + module.set_triple(&triple); + let target = Target::from_triple(&triple).unwrap(); + let machine = target + .create_target_machine( + &triple, + "generic", + "", + OptimizationLevel::None, + RelocMode::Default, + CodeModel::Default, + ) + .unwrap(); + + // Run pass (uses new pass manager) + // When compiling as IR, run this: + let passes = "instrprof"; + // When compiling as object, run this: + // let passes = "instrprof,asan-module,function(asan)"; + match module.run_passes(passes, &machine, PassBuilderOptions::create()) { + Ok(_) => (), + Err(e) => panic!("Failed to run instrprof pass: {}", e), + } +} + +/// Emits a increment counter call at the current builder position. +/// +/// `pgo_function_var` is a pointer to the function's global name variable, +/// generated from [`create_pgo_func_name_var`]. +/// +/// TODO - verify the correctness of these lifetimes. +pub fn emit_counter_increment<'ink, 'ctx>( + builder: &Builder<'ink>, + module: &Module<'ctx>, + pgo_function_var: &PointerValue<'ink>, + structural_hash: u64, + num_counters: u32, + counter_idx: u64, +) { + let context = module.get_context(); + let increment_intrinsic = Intrinsic::find("llvm.instrprof.increment").unwrap(); + let increment_intrinsic_func = increment_intrinsic.get_declaration(module, &[]).unwrap(); + + // Create types + let i64_type = context.i64_type(); + let i32_type = context.i32_type(); + + let i64_hash = i64_type.const_int(structural_hash, false); + let i32_num_counters = i32_type.const_int(num_counters.into(), false); + let i64_counter_idx = i64_type.const_int(counter_idx, false); + + builder.build_call( + increment_intrinsic_func, + &[(*pgo_function_var).into(), i64_hash.into(), i32_num_counters.into(), i64_counter_idx.into()], + "increment_call", + ); +} + +/// TODO - merge with function from above +pub fn emit_counter_increment_with_function<'ink>( + builder: &Builder<'ink>, + context: &Context, + increment_intrinsic_func: &FunctionValue<'ink>, + pgo_function_var: &PointerValue<'ink>, + structural_hash: u64, + num_counters: u32, + counter_idx: u64, +) { + // Create types + let i64_type = context.i64_type(); + let i32_type = context.i32_type(); + + let i64_hash = i64_type.const_int(structural_hash, false); + let i32_num_counters = i32_type.const_int(num_counters.into(), false); + let i64_counter_idx = i64_type.const_int(counter_idx, false); + + builder.build_call( + *increment_intrinsic_func, + &[(*pgo_function_var).into(), i64_hash.into(), i32_num_counters.into(), i64_counter_idx.into()], + "increment_call", + ); +} + +// TODO +// - investigate codegen diffs for function/function blocks/programs diff --git a/compiler/rustc_llvm_coverage/src/types.rs b/compiler/rustc_llvm_coverage/src/types.rs new file mode 100644 index 00000000000..d663f12d34a --- /dev/null +++ b/compiler/rustc_llvm_coverage/src/types.rs @@ -0,0 +1,329 @@ +#![allow(dead_code, unused_variables)] + +// These data structures provide definitions for the wrapped C LLVM interface. +// +// These type definitions are taken from: +// - [`rustc_codegen_ssa/src/coverageinfo/ffi.rs`](https://github.com/rust-lang/rust/blob/84c898d65adf2f39a5a98507f1fe0ce10a2b8dbc/compiler/rustc_codegen_ssa/src/coverageinfo/ffi.rs#L4-L5) +// - [`rustc_codegen_llvm/src/coverageinfo/ffi.rs`](https://github.com/rust-lang/rust/blob/56278a6e2824acc96b222e5816bf2d74e85dab93/compiler/rustc_codegen_llvm/src/coverageinfo/ffi.rs#L4) +// - [`rustc_middle/src/mir/coverage.rs`](https://github.com/rust-lang/rust/blob/56278a6e2824acc96b222e5816bf2d74e85dab93/compiler/rustc_middle/src/mir/coverage.rs#L9) +// + +use std::cell::RefCell; + +#[repr(C)] +pub struct RustString { + pub bytes: RefCell>, +} + +impl RustString { + pub fn new() -> Self { + Self { bytes: RefCell::new(Vec::new()) } + } + + pub fn len(&self) -> usize { + self.bytes.borrow().len() + } + + pub fn is_empty(&self) -> bool { + self.bytes.borrow().is_empty() + } +} + +#[derive(Hash, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)] +pub struct CounterId(u32); + +impl CounterId { + pub const START: Self = Self(0); + + pub fn new(value: u32) -> Self { + CounterId(value) + } + + pub fn from_u32(value: u32) -> Self { + CounterId::new(value) + } + + pub fn as_u32(&self) -> u32 { + self.0 + } +} + +#[derive(Hash, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)] +pub struct ExpressionId(u32); + +impl ExpressionId { + pub const START: Self = Self(0); + + pub fn new(value: u32) -> Self { + ExpressionId(value) + } + + pub fn from_u32(value: u32) -> Self { + ExpressionId::new(value) + } + + pub fn as_u32(&self) -> u32 { + self.0 + } +} + +/// Corresponds to enum `llvm::coverage::CounterMappingRegion::RegionKind`. +/// +/// Must match the layout of `LLVMRustCounterMappingRegionKind`. +#[derive(Copy, Clone, Debug)] +#[repr(C)] +pub enum RegionKind { + /// A CodeRegion associates some code with a counter + CodeRegion = 0, + + /// An ExpansionRegion represents a file expansion region that associates + /// a source range with the expansion of a virtual source file, such as + /// for a macro instantiation or #include file. + ExpansionRegion = 1, + + /// A SkippedRegion represents a source range with code that was skipped + /// by a preprocessor or similar means. + SkippedRegion = 2, + + /// A GapRegion is like a CodeRegion, but its count is only set as the + /// line execution count when its the only region in the line. + GapRegion = 3, + + /// A BranchRegion represents leaf-level boolean expressions and is + /// associated with two counters, each representing the number of times the + /// expression evaluates to true or false. + BranchRegion = 4, +} + +/// This struct provides LLVM's representation of a "CoverageMappingRegion", encoded into the +/// coverage map, in accordance with the +/// [LLVM Code Coverage Mapping Format](https://github.com/rust-lang/llvm-project/blob/rustc/13.0-2021-09-30/llvm/docs/CoverageMappingFormat.rst#llvm-code-coverage-mapping-format). +/// The struct composes fields representing the `Counter` type and value(s) (injected counter +/// ID, or expression type and operands), the source file (an indirect index into a "filenames +/// array", encoded separately), and source location (start and end positions of the represented +/// code region). +/// +/// Corresponds to struct `llvm::coverage::CounterMappingRegion`. +/// +/// Must match the layout of `LLVMRustCounterMappingRegion`. +#[derive(Copy, Clone, Debug)] +#[repr(C)] +pub struct CounterMappingRegion { + /// The counter type and type-dependent counter data, if any. + pub counter: Counter, + + /// If the `RegionKind` is a `BranchRegion`, this represents the counter + /// for the false branch of the region. + pub false_counter: Counter, + + /// An indirect reference to the source filename. In the LLVM Coverage Mapping Format, the + /// file_id is an index into a function-specific `virtual_file_mapping` array of indexes + /// that, in turn, are used to look up the filename for this region. + file_id: u32, + + /// If the `RegionKind` is an `ExpansionRegion`, the `expanded_file_id` can be used to find + /// the mapping regions created as a result of macro expansion, by checking if their file id + /// matches the expanded file id. + expanded_file_id: u32, + + /// 1-based starting line of the mapping region. + start_line: u32, + + /// 1-based starting column of the mapping region. + start_col: u32, + + /// 1-based ending line of the mapping region. + end_line: u32, + + /// 1-based ending column of the mapping region. If the high bit is set, the current + /// mapping region is a gap area. + end_col: u32, + + pub kind: RegionKind, +} + +impl CounterMappingRegion { + pub fn code_region( + counter: Counter, + file_id: u32, + start_line: u32, + start_col: u32, + end_line: u32, + end_col: u32, + ) -> Self { + Self { + counter, + false_counter: Counter::ZERO, + file_id, + expanded_file_id: 0, + start_line, + start_col, + end_line, + end_col, + kind: RegionKind::CodeRegion, + } + } + + // This function might be used in the future; the LLVM API is still evolving, as is coverage + // support. + // #[allow(dead_code)] + pub fn branch_region( + counter: Counter, + false_counter: Counter, + file_id: u32, + start_line: u32, + start_col: u32, + end_line: u32, + end_col: u32, + ) -> Self { + Self { + counter, + false_counter, + file_id, + expanded_file_id: 0, + start_line, + start_col, + end_line, + end_col, + kind: RegionKind::BranchRegion, + } + } + + // This function might be used in the future; the LLVM API is still evolving, as is coverage + // support. + #[allow(dead_code)] + pub(crate) fn expansion_region( + file_id: u32, + expanded_file_id: u32, + start_line: u32, + start_col: u32, + end_line: u32, + end_col: u32, + ) -> Self { + Self { + counter: Counter::ZERO, + false_counter: Counter::ZERO, + file_id, + expanded_file_id, + start_line, + start_col, + end_line, + end_col, + kind: RegionKind::ExpansionRegion, + } + } + + // This function might be used in the future; the LLVM API is still evolving, as is coverage + // support. + #[allow(dead_code)] + pub(crate) fn skipped_region( + file_id: u32, + start_line: u32, + start_col: u32, + end_line: u32, + end_col: u32, + ) -> Self { + Self { + counter: Counter::ZERO, + false_counter: Counter::ZERO, + file_id, + expanded_file_id: 0, + start_line, + start_col, + end_line, + end_col, + kind: RegionKind::SkippedRegion, + } + } + + // This function might be used in the future; the LLVM API is still evolving, as is coverage + // support. + #[allow(dead_code)] + pub(crate) fn gap_region( + counter: Counter, + file_id: u32, + start_line: u32, + start_col: u32, + end_line: u32, + end_col: u32, + ) -> Self { + Self { + counter, + false_counter: Counter::ZERO, + file_id, + expanded_file_id: 0, + start_line, + start_col, + end_line, + end_col: (1_u32 << 31) | end_col, + kind: RegionKind::GapRegion, + } + } +} + +/// Aligns with [llvm::coverage::Counter::CounterKind](https://github.com/rust-lang/llvm-project/blob/rustc/13.0-2021-09-30/llvm/include/llvm/ProfileData/Coverage/CoverageMapping.h#L95) +#[derive(Copy, Clone, Debug)] +#[repr(C)] +pub enum CounterKind { + Zero = 0, + CounterValueReference = 1, + Expression = 2, +} + +/// A reference to an instance of an abstract "counter" that will yield a value in a coverage +/// report. Note that `id` has different interpretations, depending on the `kind`: +/// * For `CounterKind::Zero`, `id` is assumed to be `0` +/// * For `CounterKind::CounterValueReference`, `id` matches the `counter_id` of the injected +/// instrumentation counter (the `index` argument to the LLVM intrinsic +/// `instrprof.increment()`) +/// * For `CounterKind::Expression`, `id` is the index into the coverage map's array of +/// counter expressions. +/// Aligns with [llvm::coverage::Counter](https://github.com/rust-lang/llvm-project/blob/rustc/13.0-2021-09-30/llvm/include/llvm/ProfileData/Coverage/CoverageMapping.h#L102-L103) +/// Important: The Rust struct layout (order and types of fields) must match its C++ counterpart. +#[derive(Copy, Clone, Debug)] +#[repr(C)] +pub struct Counter { + // Important: The layout (order and types of fields) must match its C++ counterpart. + pub kind: CounterKind, + pub id: u32, +} + +impl Counter { + /// A `Counter` of kind `Zero`. For this counter kind, the `id` is not used. + pub const ZERO: Self = Self { kind: CounterKind::Zero, id: 0 }; + + /// Constructs a new `Counter` of kind `CounterValueReference`. + pub fn counter_value_reference(counter_id: CounterId) -> Self { + Self { kind: CounterKind::CounterValueReference, id: counter_id.as_u32() } + } + + /// Constructs a new `Counter` of kind `Expression`. + pub fn expression(expression_id: ExpressionId) -> Self { + Self { kind: CounterKind::Expression, id: expression_id.as_u32() } + } +} + +/// Aligns with [llvm::coverage::CounterExpression::ExprKind](https://github.com/rust-lang/llvm-project/blob/rustc/13.0-2021-09-30/llvm/include/llvm/ProfileData/Coverage/CoverageMapping.h#L150) +#[derive(Copy, Clone, Debug)] +#[repr(C)] +pub enum ExprKind { + Subtract = 0, + Add = 1, +} + +/// Aligns with [llvm::coverage::CounterExpression](https://github.com/rust-lang/llvm-project/blob/rustc/13.0-2021-09-30/llvm/include/llvm/ProfileData/Coverage/CoverageMapping.h#L151-L152) +/// Important: The Rust struct layout (order and types of fields) must match its C++ +/// counterpart. +#[derive(Copy, Clone, Debug)] +#[repr(C)] +pub struct CounterExpression { + pub kind: ExprKind, + pub lhs: Counter, + pub rhs: Counter, +} + +impl CounterExpression { + pub fn new(lhs: Counter, kind: ExprKind, rhs: Counter) -> Self { + Self { kind, lhs, rhs } + } +} diff --git a/examples/hello_world.st b/examples/hello_world.st index 3ca0c54829e..5d157841061 100644 --- a/examples/hello_world.st +++ b/examples/hello_world.st @@ -6,5 +6,88 @@ END_VAR END_FUNCTION FUNCTION main : DINT + VAR + i : INT; + j : INT; + END_VAR puts('hello, world!$N'); + + if (1) then + puts('true$N'); + if (0) then + puts('truetrue$N'); + elsif (0) then + puts('truefalse$N'); + else + puts('truefalse$N'); + end_if; + puts('true$N'); + else + puts('false$N'); + end_if; + + CASE 1 OF + 1: + puts('Case 1$N'); + 2: + puts('Case 2$N'); + 3: + puts('Case 3$N'); + ELSE + puts('Default case$N'); + END_CASE; + + CASE 1 OF + 2: + puts('Case 2$N'); + ELSE + puts('Default case$N'); + END_CASE; + + if (0) then + puts('true$N'); + elsif (0) then + puts('false$N'); + elsif (1) then + if (1) then + puts('true$N'); + elsif (0) then + puts('false$N'); + else + puts('false$N'); + end_if + puts('true$N'); + else + // puts('false$N'); + end_if; + + FOR i := 1 TO 10 DO + FOR j := 2 TO 3 DO + puts('This is an inner loop iteration'); + END_FOR; + puts('This is a loop iteration $N'); + END_FOR; + + i := 0; + j := 0; + + WHILE i < 10 DO + puts('This is while loop iteration $N'); + i := i + 1; + + + WHILE j < 1 DO + puts('This is an inner while loop iteration $N'); + j := j + 1; + END_WHILE; + + END_WHILE; + + i := 0; + + REPEAT + puts('This is a repeat loop iteration $N'); + i := i + 1; + UNTIL i >= 10 END_REPEAT; + END_FUNCTION diff --git a/scripts/build.sh b/scripts/build.sh index 536c8389400..3a7ec3bfb58 100755 --- a/scripts/build.sh +++ b/scripts/build.sh @@ -5,12 +5,12 @@ vendor=0 offline=0 check=0 check_style=0 -build=0 +build=1 doc=0 test=0 coverage=0 release=0 -debug=0 +debug=1 container=0 assume_linux=0 junit=0 @@ -18,6 +18,8 @@ package=0 target="" CONTAINER_NAME='rust-llvm' +# There may be a better way to set these (i.e. in Dockerfile) +export PATH="$PATH:/usr/lib/llvm-14/bin" source "${BASH_SOURCE%/*}/common.sh" @@ -64,6 +66,11 @@ function run_build() { # Run cargo build with release or debug flags echo "Build starting" echo "-----------------------------------" +# # export REAL_LIBRARY_PATH_VAR="/usr/lib/llvm-14/lib" +# # export PATH="$PATH:/usr/lib/llvm-14/bin" +# cmd="REAL_LIBRARY_PATH_VAR=/usr/lib/llvm-14/lib PATH=$PATH:/usr/lib/llvm-14/bin cargo build -p rustc_llvm_coverage $CARGO_OPTIONS " +# eval "$cmd" + cmd="cargo build $CARGO_OPTIONS " log "Running $cmd" eval "$cmd" @@ -113,6 +120,25 @@ function run_std_build() { fi } +# # Builds the LLVM coverage wrapper functions +# function run_build_llvm_wrappers() { +# CARGO_OPTIONS=$(set_cargo_options) + +# # Run cargo build with release or debug flags +# echo "Build starting" +# echo "-----------------------------------" +# cmd="cargo build $CARGO_OPTIONS -p " +# log "Running $cmd" +# eval "$cmd" +# echo "-----------------------------------" +# if [[ ${PIPESTATUS[0]} -ne 0 ]]; then +# echo "Build failed" +# exit 1 +# else +# echo "Build done" +# fi +# } + function run_check() { CARGO_OPTIONS=$(set_cargo_options) log "Running cargo check" @@ -453,6 +479,61 @@ fi if [[ $build -ne 0 ]]; then run_build + # Test a program + # echo "-----------------------------------" + # echo "Running on example program:" + # # ./target/debug/plc --ir ./examples/simple_program.st + # # ./target/debug/plc --ir hello_world.st + + # # export ASAN_OPTIONS=detect_odr_violation=0 + # #./target/debug/plc --ir ./examples/hello_world.st + # # echo "-----------------------------------" + # # cat ./hello_world.st.ll + # echo "-----------------------------------" + + # # Cleanup prof + # PROFRAW=./default.profraw + # PROFEXPORT=./default.profexport + # rm -f $PROFRAW $PROFEXPORT + # rm -f hello_world.st.out + + # # # Compile + # echo "Compiling hello_world.st.ll to hello_world.st.out" + # clang++-14 -fprofile-instr-generate -fcoverage-mapping -fsanitize=address -Wl,-u,__llvm_profile_runtime -O0 hello_world.st.ll wrapper.cpp -o hello_world.st.out + # echo "Done!" + # echo "-----------------------------------" + + # # # Run + # echo "Running program" + # ./hello_world.st.out + # echo "Done!" + # echo "-----------------------------------" + + # # Show profdata + # echo "Showing profdata" + # llvm-profdata-14 show -all-functions $PROFRAW + # echo "-----------------------------------" + + # # Generate profdata + # echo "Generating profdata" + # llvm-profdata-14 merge $PROFRAW -o $PROFEXPORT + # echo "-----------------------------------" + + # # Generate report + # echo "Generating profdata" + # llvm-cov-14 report -instr-profile=$PROFEXPORT hello_world.st.out + # echo "-----------------------------------" + + # # Show coverage + # echo "Coverage" + # llvm-cov-14 show ./hello_world.st.out -instr-profile=$PROFEXPORT + # llvm-cov-14 show ./hello_world.st.out -instr-profile=$PROFEXPORT -show-regions -show-branches=count --show-expansions + # echo "-----------------------------------" + + # echo "Export" + # # llvm-cov-14 export -instr-profile=$PROFEXPORT hello_world.st.out + # echo "-----------------------------------" + #Build the standard functions run_std_build fi diff --git a/scripts/common.sh b/scripts/common.sh index ec6abae1fc2..d9e122f6511 100755 --- a/scripts/common.sh +++ b/scripts/common.sh @@ -1,6 +1,6 @@ #!/bin/bash -debug=0 +debug=1 function log() { if [[ $debug -ne 0 ]]; then >&2 echo "$1" diff --git a/scripts/debug/decode_filenames.py b/scripts/debug/decode_filenames.py new file mode 100755 index 00000000000..c9cef59d48f --- /dev/null +++ b/scripts/debug/decode_filenames.py @@ -0,0 +1,37 @@ +#!/usr/bin/python3 + +""" +Useful for decoding filenames in generated IR data. + +WARNING: This script assumes all LEB128 values are one byte only! +""" + +import zlib +import argparse + +from parse import parse_llvm_bytestring, parse_llvm_string_to_list, parse_hex_string + +# This should be the encoded data under `__llvm_prf_nm` +parser = argparse.ArgumentParser() +parser.add_argument("encoded", help="encoded data under `__llvm_prf_nm`") +args = parser.parse_args() + +encoded = args.encoded +decoded = parse_llvm_bytestring(encoded) + +# Take off the headers +num_files = decoded.pop(0) +len_uncompressed = decoded.pop(0) +len_compressed = decoded.pop(0) +assert(len_compressed == len(decoded)) + +# Decompress and separate the filenames +decoded_filenames = zlib.decompress(bytes(decoded)) +assert(len(decoded_filenames) == len_uncompressed) +filenames = parse_llvm_string_to_list(decoded_filenames) + +# Display +print(f'Files: {num_files}') +print(f'Len(uncompressed): {len_uncompressed}') +print(f'Len(compressed): {len_compressed}') +print(f'Filenames: {", ".join([filename.decode() for filename in filenames])}') diff --git a/scripts/debug/decode_pgo.py b/scripts/debug/decode_pgo.py new file mode 100755 index 00000000000..a73b848bc35 --- /dev/null +++ b/scripts/debug/decode_pgo.py @@ -0,0 +1,35 @@ +#!/usr/bin/python3 + +""" +Useful for decoding function names in generated IR data. + +WARNING: This script assumes all LEB128 values are one byte only! +""" + +import zlib +import argparse + +from parse import parse_llvm_bytestring + +# This should be the encoded data under `__llvm_coverage_mapping` +parser = argparse.ArgumentParser() +parser.add_argument("encoded", help="encoded data under `__llvm_coverage_mapping`") +args = parser.parse_args() + +encoded = args.encoded +decoded = parse_llvm_bytestring(encoded) + +# Take off the headers +len_uncompressed = decoded.pop(0) +len_compressed = decoded.pop(0) +assert(len_compressed == len(decoded)) + +# Decompress and separate the function names +decompressed_function_names = zlib.decompress(bytes(decoded)) +assert(len_uncompressed == len(decompressed_function_names)) +function_names = decompressed_function_names.split(b"\x01") + +# Display +print(f'Len(uncompressed): {len_uncompressed}') +print(f'Len(compressed): {len_compressed}') +print(f'Function names: {", ".join([function_name.decode() for function_name in function_names])}') diff --git a/scripts/debug/parse.py b/scripts/debug/parse.py new file mode 100644 index 00000000000..b8233216ac7 --- /dev/null +++ b/scripts/debug/parse.py @@ -0,0 +1,38 @@ +def parse_llvm_bytestring(encoded: str): + """ + Parse strings formatted like: + "\04\0Cx\DA\CBM\CC\CC\03\00\04\1B\01\A6" + """ + decoded = [] + while(encoded): + # \ indicates next two chars are hex + if encoded[0] == '\\': + decoded.append(int(encoded[1:3], 16)) + encoded = encoded[3:] # skip the / and the two hex letters + + # ASCII letter has the value + else: + decoded.append(ord(encoded[0])) + encoded = encoded[1:] + + return decoded + +def parse_hex_string(hex_string): + """ + Parse strings formatted like: + 0011223344.. + """ + return [int(hex_string[i:i+2], 16) for i in range(0, len(hex_string), 2)] + +def parse_llvm_string_to_list(packed_string): + """ + Unpack multiple strings formatted like: + ... + """ + values = [] + while (packed_string): + next_string_length = packed_string[0] + values.append(packed_string[1:next_string_length+1]) + packed_string = packed_string[next_string_length+1:] + + return values diff --git a/src/codegen.rs b/src/codegen.rs index fb9f5ce00e7..12f5afa143d 100644 --- a/src/codegen.rs +++ b/src/codegen.rs @@ -1,6 +1,7 @@ // Copyright (c) 2020 Ghaith Hachem and Mathias Rieder use std::{ cell::RefCell, + env, fs, ops::Deref, path::{Path, PathBuf}, }; @@ -35,12 +36,15 @@ use inkwell::{ passes::PassBuilderOptions, targets::{CodeModel, FileType, InitializationConfig, RelocMode}, }; +use instrument::CoverageInstrumentationBuilder; use plc_ast::ast::{CompilationUnit, LinkageType}; use plc_diagnostics::diagnostics::Diagnostic; use plc_source::source_location::SourceLocation; +use rustc_llvm_coverage::types::CounterMappingRegion; mod debug; pub(crate) mod generators; +mod instrument; mod llvm_index; mod llvm_typesystem; #[cfg(test)] @@ -70,7 +74,8 @@ pub struct CodeGen<'ink> { pub module: Module<'ink>, /// the debugging module creates debug information at appropriate locations pub debug: DebugBuilderEnum<'ink>, - + /// the instrumentation builder (possibly later hoisted out of the codegen struct) + pub instrument: Option>, pub module_location: String, } @@ -94,7 +99,20 @@ impl<'ink> CodeGen<'ink> { let module = context.create_module(module_location); module.set_source_file_name(module_location); let debug = debug::DebugBuilderEnum::new(context, &module, root, optimization_level, debug_level); - CodeGen { module, debug, module_location: module_location.to_string() } + + // TODO - disable instr here + let current_dir = env::current_dir().expect("Failed to get current directory"); + + let filenames = vec![current_dir.to_str().unwrap().to_string(), module_location.to_string()]; + let instr_builder = instrument::CoverageInstrumentationBuilder::new(context, filenames); + // instrument::CoverageInstrumentationBuilder::new(/*context, */ &module, filenames); + + CodeGen { + module, + debug, + instrument: Some(instr_builder), + module_location: module_location.to_string(), + } } pub fn generate_llvm_index( @@ -183,12 +201,19 @@ impl<'ink> CodeGen<'ink> { index.associate_utf16_literal(literal, literal_variable); } + // Add the increment intrinsic to the index + if let Some(instr_builder) = &mut self.instrument { + let increment_function = instr_builder.get_increment_function(&self.module); + let increment_function_name = increment_function.get_name().to_str().unwrap(); + index.associate_implementation(increment_function_name, increment_function); + } + Ok(index) } /// generates all TYPEs, GLOBAL-sections and POUs of the given CompilationUnit pub fn generate( - self, + mut self, context: &'ink CodegenContext, unit: &CompilationUnit, annotations: &AstAnnotations, @@ -199,12 +224,27 @@ impl<'ink> CodeGen<'ink> { let llvm = Llvm::new(context, context.create_builder()); let pou_generator = PouGenerator::new(llvm, global_index, annotations, llvm_index); + if let Some(instr_builder) = &mut self.instrument { + // Mark functions as sanitize + instr_builder.sanitize_functions(unit, llvm_index, &self.module); + + // Generate mapping header + instr_builder.initialize(&self.module); + + instr_builder.create_function_records(unit, llvm_index, &self.module); + } + //Generate the POU stubs in the first go to make sure they can be referenced. for implementation in &unit.implementations { //Don't generate external or generic functions if let Some(entry) = global_index.find_pou(implementation.name.as_str()) { if !entry.is_generic() && entry.get_linkage() != &LinkageType::External { - pou_generator.generate_implementation(implementation, &self.debug)?; + pou_generator.generate_implementation( + implementation, + &self.debug, + &self.module, + &self.instrument, + )?; } } } @@ -212,6 +252,11 @@ impl<'ink> CodeGen<'ink> { self.debug.finalize(); log::debug!("{}", self.module.to_string()); + // Run the pass + if let Some(instr_builder) = &mut self.instrument { + instr_builder.finalize(&self.module); + } + #[cfg(feature = "verify")] { self.module diff --git a/src/codegen/generators/pou_generator.rs b/src/codegen/generators/pou_generator.rs index e977f45acec..f5ab292c0ed 100644 --- a/src/codegen/generators/pou_generator.rs +++ b/src/codegen/generators/pou_generator.rs @@ -10,6 +10,7 @@ use super::{ use crate::{ codegen::{ debug::{Debug, DebugBuilderEnum}, + instrument::CoverageInstrumentationBuilder, llvm_index::LlvmTypedIndex, }, index::{self, ImplementationType}, @@ -28,6 +29,7 @@ use crate::index::{ImplementationIndexEntry, VariableIndexEntry}; use crate::index::Index; use indexmap::{IndexMap, IndexSet}; use inkwell::{ + intrinsics::Intrinsic, module::Module, types::{BasicMetadataTypeEnum, BasicTypeEnum, FunctionType}, values::{BasicValue, BasicValueEnum, FunctionValue}, @@ -277,10 +279,12 @@ impl<'ink, 'cg> PouGenerator<'ink, 'cg> { } /// generates a function for the given pou - pub fn generate_implementation( + pub fn generate_implementation<'ctx>( &self, implementation: &Implementation, debug: &DebugBuilderEnum<'ink>, + module: &Module<'ctx>, + instrumentation: &Option>, ) -> Result<(), Diagnostic> { let context = self.llvm.context; let mut local_index = LlvmTypedIndex::create_child(self.llvm_index); @@ -305,7 +309,7 @@ impl<'ink, 'cg> PouGenerator<'ink, 'cg> { .unwrap(); debug.set_debug_location(&self.llvm, ¤t_function, line, column); - //generate the body + //generate the body - "entry" block marks the beginning of the function let block = context.append_basic_block(current_function, "entry"); //Create all labels this function will have @@ -318,6 +322,8 @@ impl<'ink, 'cg> PouGenerator<'ink, 'cg> { self.llvm.builder.position_at_end(block); blocks.insert("entry".into(), block); + // TODO - don't hardcode this name + let increment_function = local_index.find_associated_implementation("llvm.instrprof.increment"); let function_context = FunctionContext { linking_context: self.index.find_implementation_by_name(&implementation.name).ok_or_else( || { @@ -329,6 +335,7 @@ impl<'ink, 'cg> PouGenerator<'ink, 'cg> { )?, function: current_function, blocks, + increment_function, }; let mut param_index = 0; @@ -386,7 +393,19 @@ impl<'ink, 'cg> PouGenerator<'ink, 'cg> { &local_index, &function_context, debug, + instrumentation, ); + + // Emit instrumentation before function body + if let Some(instr_builder) = instrumentation { + instr_builder.emit_function_increment( + &self.llvm.builder, + module, + current_function.get_name().to_str().unwrap(), + 0, + ); + } + statement_gen.generate_body(&implementation.statements)?; statement_gen.generate_return_statement()?; } diff --git a/src/codegen/generators/statement_generator.rs b/src/codegen/generators/statement_generator.rs index 0a65c636f72..daa25019a4a 100644 --- a/src/codegen/generators/statement_generator.rs +++ b/src/codegen/generators/statement_generator.rs @@ -6,6 +6,7 @@ use super::{ llvm::Llvm, }; use crate::{ + codegen::instrument::CoverageInstrumentationBuilder, codegen::{debug::Debug, llvm_typesystem::cast_if_needed}, codegen::{debug::DebugBuilderEnum, LlvmTypedIndex}, index::{ImplementationIndexEntry, Index}, @@ -16,6 +17,7 @@ use inkwell::{ basic_block::BasicBlock, builder::Builder, context::Context, + module::Module, values::{BasicValueEnum, FunctionValue, PointerValue}, }; use plc_ast::{ @@ -36,6 +38,9 @@ pub struct FunctionContext<'ink, 'b> { pub function: FunctionValue<'ink>, /// The blocks/labels this function can use pub blocks: HashMap>, + // TODO - there may be a better spot for this, but it works for now + /// Increment intrinsic + pub increment_function: Option>, } /// the StatementCodeGenerator is used to generate statements (For, If, etc.) or expressions (references, literals, etc.) @@ -55,6 +60,7 @@ pub struct StatementCodeGenerator<'a, 'b> { pub current_loop_continue: Option>, pub debug: &'b DebugBuilderEnum<'a>, + pub instrumentation: &'b Option>, } impl<'a, 'b> StatementCodeGenerator<'a, 'b> { @@ -66,6 +72,7 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { llvm_index: &'b LlvmTypedIndex<'a>, linking_context: &'b FunctionContext<'a, 'b>, debug: &'b DebugBuilderEnum<'a>, + instrumentation: &'b Option>, ) -> StatementCodeGenerator<'a, 'b> { StatementCodeGenerator { llvm, @@ -78,10 +85,11 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { current_loop_exit: None, current_loop_continue: None, debug, + instrumentation, } } - /// convinience method to create an expression-generator + /// convenience method to create an expression-generator fn create_expr_generator(&'a self) -> ExpressionCodeGenerator<'a, 'b> { ExpressionCodeGenerator::new( self.llvm, @@ -420,6 +428,20 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { load_suffix: self.load_suffix.clone(), ..*self }; + + // Generate counter increment + if let Some(instr_builder) = self.instrumentation { + if let Some(first_ast) = body.first() { + instr_builder.emit_branch_increment( + builder, + context, + &self.get_increment_function(), + current_function.get_name().to_str().unwrap(), + first_ast.id, + ); + } + } + body_generator.generate_body(body)?; builder.build_unconditional_branch(increment_block); @@ -455,6 +477,17 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { //Continue builder.position_at_end(continue_block); + // Increment false counter + if let Some(instr_builder) = self.instrumentation { + instr_builder.emit_branch_increment( + builder, + context, + &self.get_increment_function(), + current_function.get_name().to_str().unwrap(), + end.id, + ); + } + Ok(()) } @@ -557,12 +590,40 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { } //generate the case's body builder.position_at_end(case_block); + + // Generate counter increment + if let Some(instr_builder) = self.instrumentation { + if let Some(first_ast) = conditional_block.body.first() { + instr_builder.emit_branch_increment( + builder, + context, + &self.get_increment_function(), + current_function.get_name().to_str().unwrap(), + first_ast.id, + ); + } + } + self.generate_body(&conditional_block.body)?; // skiop all other case-bodies builder.build_unconditional_branch(continue_block); } // current-else is the last else-block generated by the range-expressions builder.position_at_end(current_else_block); + + // Generate counter increment + if let Some(instr_builder) = self.instrumentation { + if let Some(first_ast) = else_body.first() { + instr_builder.emit_branch_increment( + builder, + context, + &self.get_increment_function(), + current_function.get_name().to_str().unwrap(), + first_ast.id, + ); + } + } + self.generate_body(else_body)?; builder.build_unconditional_branch(continue_block); continue_block.move_after(current_else_block).expect(INTERNAL_LLVM_ERROR); @@ -703,12 +764,35 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { load_suffix: self.load_suffix.clone(), ..*self }; + if let Some(instr_builder) = self.instrumentation { + if let Some(first_ast) = body.first() { + instr_builder.emit_branch_increment( + builder, + context, + &self.get_increment_function(), + current_function.get_name().to_str().unwrap(), + first_ast.id, + ); + } + } body_generator.generate_body(body)?; //Loop back builder.build_unconditional_branch(condition_check); //Continue builder.position_at_end(continue_block); + + // Increment false counter + if let Some(instr_builder) = self.instrumentation { + instr_builder.emit_branch_increment( + builder, + context, + &self.get_increment_function(), + current_function.get_name().to_str().unwrap(), + condition.id, + ); + } + Ok((condition_check, while_body)) } @@ -758,6 +842,20 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { //Generate if statement content builder.position_at_end(conditional_block); + + // Generate counter increment + if let Some(instr_builder) = self.instrumentation { + if let Some(first_ast) = block.body.first() { + instr_builder.emit_branch_increment( + builder, + context, + &self.get_increment_function(), + current_function.get_name().to_str().unwrap(), + first_ast.id, + ); + } + } + self.generate_body(&block.body)?; builder.build_unconditional_branch(continue_block); } @@ -765,6 +863,21 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { if let Some(else_block) = else_block { builder.position_at_end(else_block); + // TODO - if else block[0] exists, then generate a ctr increment real quick + + // Generate counter increment + if let Some(instr_builder) = self.instrumentation { + if let Some(first_ast) = else_body.first() { + instr_builder.emit_branch_increment( + builder, + context, + &self.get_increment_function(), + current_function.get_name().to_str().unwrap(), + first_ast.id, + ); + } + } + self.generate_body(else_body)?; builder.build_unconditional_branch(continue_block); } @@ -841,6 +954,10 @@ impl<'a, 'b> StatementCodeGenerator<'a, 'b> { fn get_llvm_deps(&self) -> (&Builder, FunctionValue, &Context) { (&self.llvm.builder, self.function_context.function, self.llvm.context) } + + fn get_increment_function(&self) -> FunctionValue { + self.function_context.increment_function.expect("Increment function not set") + } } /// when generating an assignment to a direct-access (e.g. a.b.c.%W3.%X2 := 2;) diff --git a/src/codegen/instrument.rs b/src/codegen/instrument.rs new file mode 100644 index 00000000000..6f180856e45 --- /dev/null +++ b/src/codegen/instrument.rs @@ -0,0 +1,534 @@ +use super::LlvmTypedIndex; +use inkwell::attributes::{Attribute, AttributeLoc}; +use inkwell::builder::Builder; +use inkwell::context::Context; +use inkwell::intrinsics::Intrinsic; +use inkwell::module::Module; +use inkwell::values::{FunctionValue, GlobalValue}; +use plc_ast::ast::{AstId, AstNode, AstStatement, CompilationUnit, Implementation, LinkageType}; +use plc_ast::control_statements::AstControlStatement; +use plc_source::source_location::{CodeSpan, SourceLocation}; +use rustc_llvm_coverage::types::{ + Counter, CounterExpression, CounterId, CounterMappingRegion, ExprKind, ExpressionId, +}; +use rustc_llvm_coverage::*; +use std::collections::HashMap; +use std::ffi::CString; + +pub struct CoverageInstrumentationBuilder<'ink> { + context: &'ink Context, + // module: &'ink Module<'ink>, + files: Vec, + cov_mapping_header: Option, + function_pgos: HashMap)>, + // TODO - better counter datastructures + // ast_counter_lookup: + // - if statements: map first body block -> true branch counter + // - case statements: map first body block -> true branch counter + // - for statements: map first body block -> true branch counter AND end block -> false branch counter + // - while statements: map first body block -> true branch counter AND condition block -> false branch counter + // - repeat statements: map first body block -> true branch counter AND condition block -> false branch counter + ast_counter_lookup: HashMap, +} + +/// Manages the creation of mapping regions for a given function +#[derive(Debug)] +struct MappingRegionGenerator { + pub mapping_regions: Vec, + pub expressions: Vec, + next_counter_id: u32, + next_expression_id: u32, + file_id: u32, +} + +impl MappingRegionGenerator { + pub fn new(file_id: u32) -> Self { + Self { + mapping_regions: Vec::new(), + expressions: Vec::new(), + next_counter_id: 0, + next_expression_id: 0, + file_id, + } + } + + /// Adds to internal index and returns for convenience + pub fn add_code_mapping_region(&mut self, source: &SourceLocation) -> CounterMappingRegion { + let (start_line, start_col, end_line, end_col) = source.get_start_end(); + + let counter_id = self.next_counter_id; + let counter = Counter::counter_value_reference(CounterId::new(counter_id)); + self.next_counter_id += 1; + + let mapping_region = CounterMappingRegion::code_region( + counter, + self.file_id, + start_line.try_into().unwrap(), + start_col.try_into().unwrap(), + end_line.try_into().unwrap(), + end_col.try_into().unwrap(), + ); + self.mapping_regions.push(mapping_region.clone()); + + mapping_region + } + + // TODO - consolidate the two below functions + /// Adds to internal index and returns for convenience + /// Specific to if statements and case statements + pub fn add_branch_mapping_region( + &mut self, + source: &SourceLocation, + last_false_counter: Counter, + ) -> CounterMappingRegion { + let (start_line, start_col, end_line, end_col) = source.get_start_end(); + + // Counts branch executions + let counter_id = self.next_counter_id; + let counter = Counter::counter_value_reference(CounterId::new(counter_id)); + self.next_counter_id += 1; + + // Count the branch skips (when cond evalutes to false using a_{n-1} - Counter) + let false_counter_id = self.next_expression_id; + let false_counter = Counter::expression(ExpressionId::new(false_counter_id)); + let false_counter_expression = + CounterExpression::new(last_false_counter, ExprKind::Subtract, counter); + self.expressions.push(false_counter_expression); + self.next_expression_id += 1; + + let mapping_region = CounterMappingRegion::branch_region( + counter, + false_counter, + self.file_id, + start_line.try_into().unwrap(), + start_col.try_into().unwrap(), + end_line.try_into().unwrap(), + end_col.try_into().unwrap(), + ); + self.mapping_regions.push(mapping_region.clone()); + + // Return the index of the counter id added + mapping_region + } + + pub fn add_loop_branch_mapping_region( + &mut self, + source: &SourceLocation, + false_counter: Counter, + ) -> CounterMappingRegion { + let (start_line, start_col, end_line, end_col) = source.get_start_end(); + + // Counts branch executions + let counter_id = self.next_counter_id; + let counter = Counter::counter_value_reference(CounterId::new(counter_id)); + self.next_counter_id += 1; + + let mapping_region = CounterMappingRegion::branch_region( + counter, + false_counter, + self.file_id, + start_line.try_into().unwrap(), + start_col.try_into().unwrap(), + end_line.try_into().unwrap(), + end_col.try_into().unwrap(), + ); + self.mapping_regions.push(mapping_region.clone()); + + // Return the index of the counter id added + mapping_region + } +} + +impl<'ink> CoverageInstrumentationBuilder<'ink> { + pub fn new(context: &'ink Context, /* module: &'ink Module<'ink>,*/ files: Vec) -> Self { + Self { + context, + // module, + files, + cov_mapping_header: None, + function_pgos: HashMap::new(), + ast_counter_lookup: HashMap::new(), + } + } + + pub fn initialize(&mut self, module: &Module<'ink>) { + let cov_mapping_header = CoverageMappingHeader::new(self.files.clone()); + cov_mapping_header.write_coverage_mapping_header(module); + self.cov_mapping_header = Some(cov_mapping_header); + } + + pub fn create_function_records( + &mut self, + unit: &CompilationUnit, + llvm_index: &LlvmTypedIndex, + module: &Module<'ink>, + ) { + // Keep records + let mut function_records = Vec::new(); + + // Loop through functions in AST, create function records + for implementation in &unit.implementations { + // Skip non-internal functions (external links + built-ins) + if implementation.linkage != LinkageType::Internal { + continue; + } + // Skip no-definition functions + // TODO - investigate which functions don't have definitions and why + if module.get_function(&implementation.name).is_none() { + println!("Skipping undefined function: {}", &implementation.name); + continue; + } + + let func = self.generate_function_record(implementation); + func.write_to_module(module); + + function_records.push(func); + } + + // Loop through LLVM definitions, create PGO vars + for function_record in function_records { + let func_name = function_record.name.clone(); + + let func = module + .get_function(&func_name) + .expect(&format!("Function not found in module: {}", func_name)); + + let func_pgo = rustc_llvm_coverage::interfaces::create_pgo_func_name_var(&func); + + &self.function_pgos.insert(func_name, (function_record, func_pgo)); + } + } + + pub fn emit_function_increment<'ctx>( + &self, + builder: &Builder<'ink>, + module: &Module<'ctx>, + func_name: &str, + counter_index: u64, + ) { + let (func_record, func_pgo_var) = self.function_pgos.get(func_name).unwrap(); + + let pgo_pointer = func_pgo_var.as_pointer_value(); + let num_counters = func_record.mapping_regions.len(); + + rustc_llvm_coverage::emit_counter_increment( + builder, + module, + &pgo_pointer, + func_record.structural_hash, + num_counters.try_into().unwrap(), + counter_index, + ); + } + + pub fn emit_branch_increment<'ctx>( + &self, + builder: &Builder<'ink>, + context: &Context, + increment_intrinsic_func: &FunctionValue<'ink>, + func_name: &str, + ast_id: AstId, + ) { + let (func_record, func_pgo_var) = self.function_pgos.get(func_name).unwrap(); + + let pgo_pointer = func_pgo_var.as_pointer_value(); + let num_counters = func_record.mapping_regions.len(); + + let counter_index = match self.ast_counter_lookup.get(&ast_id) { + Some(counter_index) => counter_index, + None => { + // TODO - figure out why this happens + println!("Ast Not Registered: {} (from function {})", ast_id, func_name); + return; + } + }; + + rustc_llvm_coverage::emit_counter_increment_with_function( + builder, + context, + increment_intrinsic_func, + &pgo_pointer, + func_record.structural_hash, + num_counters.try_into().unwrap(), + *counter_index as u64, + ); + } + + pub fn finalize(&mut self, module: &Module<'ink>) { + run_instrumentation_lowering_pass(module); + } + + /// Internal function to generate for a function: + /// - FunctionRecord + /// - MappingRegions + fn generate_function_record(&mut self, implementation: &Implementation) -> FunctionRecord { + // Gather function information + + let func_name = implementation.name.clone(); + // TODO - hash strucutrally + let struct_hash = rustc_llvm_coverage::interfaces::hash_str(&func_name); + let func_filenames = vec![implementation.location.get_file_name().unwrap().to_string()]; + // TODO - file mapping table + let file_id = 1; + + // Map entire function + let mut mapping_region_generator = MappingRegionGenerator::new(file_id); + let func_map_region = mapping_region_generator.add_code_mapping_region(&implementation.location); + assert!(func_map_region.counter.id == 0); + + // DFS function statements + self.generate_coverage_records( + &implementation.statements, + &mut mapping_region_generator, + func_map_region.counter, + ); + + // TODO - determine if function is used + let is_used = true; + let written_coverage_header = &self.cov_mapping_header.as_mut().unwrap(); + + FunctionRecord::new( + func_name, + struct_hash, + func_filenames, + mapping_region_generator.expressions, + mapping_region_generator.mapping_regions, + is_used, + &written_coverage_header, + ) + } + + /// DFS algorithm to parse AST and generate coverage records for all branching + /// `parent_counter_id` is the counter id of the parent node, used for calculating "false" branches under if/case statements + /// TODO - explain or diagram what's going on here with + /// - last_false_counter: for chain branches + /// - last_true_counter: for recursing + fn generate_coverage_records( + &mut self, + ast_node_list: &Vec, + mapping_region_generator: &mut MappingRegionGenerator, + parent_counter: Counter, + ) { + for ast_node in ast_node_list { + // Only generate coverage records for control statements + let control_statement = match &ast_node.stmt { + AstStatement::ControlStatement(statement) => statement, + _ => continue, + }; + + // Track last counter (a_{n-1}) - useful for false branch calculations + // Must be initialized to parent counter in first iteration + let mut last_true_counter = parent_counter; + let mut last_false_counter = parent_counter; + + // + match control_statement { + AstControlStatement::If(statement) => { + // Loop through if/elif blocks + for block in &statement.blocks { + // Setup ast->id mapping, store region location + (last_true_counter, last_false_counter) = self.register_ast_list_as_branch_region( + &block.body, + mapping_region_generator, + last_true_counter, + last_false_counter, + ); + // Recurse into child blocks + self.generate_coverage_records( + &block.body, + mapping_region_generator, + last_true_counter, + ); + } + + // Else block ast->id mapping + (last_true_counter, last_false_counter) = self.register_ast_list_as_branch_region( + &statement.else_block, + mapping_region_generator, + last_true_counter, + last_false_counter, + ); + // Recurse into child blocks + self.generate_coverage_records( + &statement.else_block, + mapping_region_generator, + last_true_counter, + ); + } + AstControlStatement::ForLoop(statement) => { + // Loop through for loop body + (last_true_counter, last_false_counter) = self.register_ast_list_as_loop_branch_region( + &statement.body, + &statement.end, + mapping_region_generator, + last_true_counter, + last_false_counter, + ); + self.generate_coverage_records( + &statement.body, + mapping_region_generator, + last_true_counter, + ); + } + AstControlStatement::WhileLoop(statement) => { + // Loop through while loop body + (last_true_counter, last_false_counter) = self.register_ast_list_as_loop_branch_region( + &statement.body, + &statement.condition, + mapping_region_generator, + last_true_counter, + last_false_counter, + ); + self.generate_coverage_records( + &statement.body, + mapping_region_generator, + last_true_counter, + ); + } + AstControlStatement::RepeatLoop(statement) => { + // Loop through while loop body + (last_true_counter, last_false_counter) = self.register_ast_list_as_loop_branch_region( + &statement.body, + &statement.condition, + mapping_region_generator, + last_true_counter, + last_false_counter, + ); + self.generate_coverage_records( + &statement.body, + mapping_region_generator, + last_true_counter, + ); + } + AstControlStatement::Case(statement) => { + // Loop through case blocks + for block in &statement.case_blocks { + // Setup ast->id mapping, store region location + (last_true_counter, last_false_counter) = self.register_ast_list_as_branch_region( + &block.body, + mapping_region_generator, + last_true_counter, + last_false_counter, + ); + // Recurse + self.generate_coverage_records( + &block.body, + mapping_region_generator, + last_true_counter, + ); + } + + // Else block + (last_true_counter, last_false_counter) = self.register_ast_list_as_branch_region( + &statement.else_block, + mapping_region_generator, + last_true_counter, + last_false_counter, + ); + self.generate_coverage_records( + &statement.else_block, + mapping_region_generator, + last_true_counter, + ); + } + } + } + } + + // TODO - find me a better name + /// Registers a Vec as a region, spanning first and last + /// Returns the true and false counters for DFS + fn register_ast_list_as_branch_region( + &mut self, + ast_list: &Vec, + mapping_region_generator: &mut MappingRegionGenerator, + last_true_counter: Counter, + last_false_counter: Counter, + ) -> (Counter, Counter) { + if ast_list.is_empty() { + return (last_true_counter, last_false_counter); + } + + // Create a span from first_block -> last_block + let first_block = ast_list.first().unwrap(); + let last_block = ast_list.last().unwrap(); + let span = first_block.location.span(&last_block.location); + + // Map the span, store the counter id in the lookup table (key at first_block.ast_id) + let mapping_region = mapping_region_generator.add_branch_mapping_region(&span, last_false_counter); + self.ast_counter_lookup.insert(first_block.id, mapping_region.counter.id.try_into().unwrap()); + + (mapping_region.counter, mapping_region.false_counter) + } + + // TODO - find a better name + fn register_ast_list_as_loop_branch_region( + &mut self, + loop_body_ast_list: &Vec, + loop_condition_ast: &AstNode, + mapping_region_generator: &mut MappingRegionGenerator, + last_true_counter: Counter, + last_false_counter: Counter, + ) -> (Counter, Counter) { + if loop_body_ast_list.is_empty() { + return (last_true_counter, last_false_counter); + } + + // Create a counter for the condition ast + // this is a temporary hack, because false_counter will not actually create + // a counter that doesn't exist, only reference once + let condition_mapping_region = + mapping_region_generator.add_code_mapping_region(&loop_condition_ast.location); + + // Create a span from first_block -> last_block + let first_block = loop_body_ast_list.first().unwrap(); + let last_block = loop_body_ast_list.last().unwrap(); + let span = first_block.location.span(&last_block.location); + + // Map the span, store the counter id in the lookup table (key at first_block.ast_id) + let mapping_region = + mapping_region_generator.add_loop_branch_mapping_region(&span, condition_mapping_region.counter); + + // Map loop body -> true branch counter + // Map loop condition -> false branch counter + self.ast_counter_lookup.insert(first_block.id, mapping_region.counter.id.try_into().unwrap()); + self.ast_counter_lookup + .insert(loop_condition_ast.id, mapping_region.false_counter.id.try_into().unwrap()); + + (mapping_region.counter, mapping_region.false_counter) + } + + pub fn get_increment_function(&self, module: &Module<'ink>) -> FunctionValue<'ink> { + let increment_intrinsic = Intrinsic::find("llvm.instrprof.increment").unwrap(); + let increment_intrinsic_func = increment_intrinsic.get_declaration(module, &[]).unwrap(); + + increment_intrinsic_func + } + + pub fn sanitize_functions( + &mut self, + unit: &CompilationUnit, + llvm_index: &LlvmTypedIndex, + module: &Module<'ink>, + ) { + for implementation in &unit.implementations { + // Skip non-internal functions (external links + built-ins) + if implementation.linkage != LinkageType::Internal { + continue; + } + // Skip no-definition functions + // TODO - investigate which functions don't have definitions and why + if module.get_function(&implementation.name).is_none() { + println!("Skipping undefined function: {}", &implementation.name); + continue; + } + + let func = llvm_index.find_associated_implementation(&implementation.name).unwrap(); + + let context = module.get_context(); + let sanitizer_attribute_id = Attribute::get_named_enum_kind_id("sanitize_address"); + let sanitizer_attribute = context.create_enum_attribute(sanitizer_attribute_id, 0); + func.add_attribute(AttributeLoc::Function, sanitizer_attribute); + } + } +}