Skip to content

Improve JALR execution with JIT-cache #471

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ci/riscv-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
set -e -u -o pipefail

# Install RISCOF
python3 -m pip install git+https://github.com/riscv/riscof
pip3 install git+https://github.com/riscv/riscof.git@d38859f85fe407bcacddd2efcd355ada4683aee4

set -x

Expand Down
Binary file added build/fibonacci.elf
Binary file not shown.
1 change: 1 addition & 0 deletions src/jit.c
Original file line number Diff line number Diff line change
Expand Up @@ -1864,6 +1864,7 @@ static void code_cache_flush(struct jit_state *state, riscv_t *rv)
state->offset = state->org_size;
state->n_blocks = 0;
set_reset(&state->set);
jit_cache_clear(rv->jit_cache);
clear_cache_hot(rv->block_cache, (clear_func_t) clear_hot);
return;
}
Expand Down
22 changes: 21 additions & 1 deletion src/jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,26 @@ void jit_translate(riscv_t *rv, block_t *block);
typedef void (*exec_block_func_t)(riscv_t *rv, uintptr_t);

#if RV32_HAS(T2C)
void t2c_compile(block_t *block, uint64_t mem_base);
void t2c_compile(riscv_t *, block_t *);
typedef void (*exec_t2c_func_t)(riscv_t *);

/* The jit-cache records the program counters and the entries of executable
* instructions generated by T2C. Like hardware cache, the old jit-cache will be
* replaced by the new one which uses the same slot.
*/

/* The size of jit-cache table should be the power of 2, thus, we can easily
* access the element by masking the program counter.
*/
#define N_JIT_CACHE_ENTRIES (1 << 12)

struct jit_cache {
uint64_t pc; /* program counter, easy to build LLVM IR with 64-bit width */
void *entry; /* entry of JIT-ed code */
};

struct jit_cache *jit_cache_init();
void jit_cache_exit(struct jit_cache *cache);
void jit_cache_update(struct jit_cache *cache, uint32_t pc, void *entry);
void jit_cache_clear(struct jit_cache *cache);
#endif
5 changes: 3 additions & 2 deletions src/riscv.c
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,7 @@ static void *t2c_runloop(void *arg)
pthread_mutex_lock(&rv->wait_queue_lock);
list_del_init(&entry->list);
pthread_mutex_unlock(&rv->wait_queue_lock);
t2c_compile(entry->block,
(uint64_t) ((memory_t *) PRIV(rv)->mem)->mem_base);
t2c_compile(rv, entry->block);
free(entry);
}
}
Expand Down Expand Up @@ -291,6 +290,7 @@ riscv_t *rv_create(riscv_user_t rv_attr)
mpool_create(sizeof(chain_entry_t) << BLOCK_IR_MAP_CAPACITY_BITS,
sizeof(chain_entry_t));
rv->jit_state = jit_state_init(CODE_CACHE_SIZE);
rv->jit_cache = jit_cache_init();
rv->block_cache = cache_create(BLOCK_MAP_CAPACITY_BITS);
assert(rv->block_cache);
#if RV32_HAS(T2C)
Expand Down Expand Up @@ -392,6 +392,7 @@ void rv_delete(riscv_t *rv)
#endif
mpool_destroy(rv->chain_entry_mp);
jit_state_exit(rv->jit_state);
jit_cache_exit(rv->jit_cache);
cache_free(rv->block_cache);
#endif
free(rv);
Expand Down
1 change: 1 addition & 0 deletions src/riscv_private.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ struct riscv_internal {
struct mpool *block_mp, *block_ir_mp;

void *jit_state;
void *jit_cache;
#if RV32_HAS(GDBSTUB)
/* gdbstub instance */
gdbstub_t gdbstub;
Expand Down
72 changes: 55 additions & 17 deletions src/t2c.c
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,15 @@ FORCE_INLINE LLVMBasicBlockRef t2c_block_map_search(struct LLVM_block_map *map,
return NULL;
}

#define T2C_OP(inst, code) \
static void t2c_##inst( \
LLVMBuilderRef *builder UNUSED, LLVMTypeRef *param_types UNUSED, \
LLVMValueRef start UNUSED, LLVMBasicBlockRef *entry UNUSED, \
LLVMBuilderRef *taken_builder UNUSED, \
LLVMBuilderRef *untaken_builder UNUSED, uint64_t mem_base UNUSED, \
rv_insn_t *ir UNUSED) \
{ \
code; \
#define T2C_OP(inst, code) \
static void t2c_##inst( \
LLVMBuilderRef *builder UNUSED, LLVMTypeRef *param_types UNUSED, \
LLVMValueRef start UNUSED, LLVMBasicBlockRef *entry UNUSED, \
LLVMBuilderRef *taken_builder UNUSED, \
LLVMBuilderRef *untaken_builder UNUSED, riscv_t *rv UNUSED, \
uint64_t mem_base UNUSED, rv_insn_t *ir UNUSED) \
{ \
code; \
}

#define T2C_LLVM_GEN_ADDR(reg, rv_member, ir_member) \
Expand Down Expand Up @@ -135,6 +135,9 @@ FORCE_INLINE void t2c_gen_call_io_func(LLVMValueRef start,
&io_param, 1, "");
}

static LLVMTypeRef t2c_jit_cache_func_type;
static LLVMTypeRef t2c_jit_cache_struct_type;

#include "t2c_template.c"
#undef T2C_OP

Expand Down Expand Up @@ -174,14 +177,15 @@ typedef void (*t2c_codegen_block_func_t)(LLVMBuilderRef *builder UNUSED,
LLVMBasicBlockRef *entry UNUSED,
LLVMBuilderRef *taken_builder UNUSED,
LLVMBuilderRef *untaken_builder UNUSED,
riscv_t *rv UNUSED,
uint64_t mem_base UNUSED,
rv_insn_t *ir UNUSED);

static void t2c_trace_ebb(LLVMBuilderRef *builder,
LLVMTypeRef *param_types UNUSED,
LLVMValueRef start,
LLVMBasicBlockRef *entry,
uint64_t mem_base,
riscv_t *rv,
rv_insn_t *ir,
set_t *set,
struct LLVM_block_map *map)
Expand All @@ -194,7 +198,8 @@ static void t2c_trace_ebb(LLVMBuilderRef *builder,

while (1) {
((t2c_codegen_block_func_t) dispatch_table[ir->opcode])(
builder, param_types, start, entry, &tk, &utk, mem_base, ir);
builder, param_types, start, entry, &tk, &utk, rv,
(uint64_t) ((memory_t *) PRIV(rv)->mem)->mem_base, ir);
if (!ir->next)
break;
ir = ir->next;
Expand All @@ -214,8 +219,7 @@ static void t2c_trace_ebb(LLVMBuilderRef *builder,
LLVMPositionBuilderAtEnd(untaken_builder, untaken_entry);
LLVMBuildBr(utk, untaken_entry);
t2c_trace_ebb(&untaken_builder, param_types, start,
&untaken_entry, mem_base, ir->branch_untaken, set,
map);
&untaken_entry, rv, ir->branch_untaken, set, map);
}
}
if (ir->branch_taken) {
Expand All @@ -230,13 +234,13 @@ static void t2c_trace_ebb(LLVMBuilderRef *builder,
LLVMPositionBuilderAtEnd(taken_builder, taken_entry);
LLVMBuildBr(tk, taken_entry);
t2c_trace_ebb(&taken_builder, param_types, start, &taken_entry,
mem_base, ir->branch_taken, set, map);
rv, ir->branch_taken, set, map);
}
}
}
}

void t2c_compile(block_t *block, uint64_t mem_base)
void t2c_compile(riscv_t *rv, block_t *block)
{
LLVMModuleRef module = LLVMModuleCreateWithName("my_module");
LLVMTypeRef io_members[] = {
Expand All @@ -254,6 +258,16 @@ void t2c_compile(block_t *block, uint64_t mem_base)
LLVMTypeRef param_types[] = {LLVMPointerType(struct_rv, 0)};
LLVMValueRef start = LLVMAddFunction(
module, "start", LLVMFunctionType(LLVMVoidType(), param_types, 1, 0));

LLVMTypeRef t2c_args[1] = {LLVMInt64Type()};
t2c_jit_cache_func_type =
LLVMFunctionType(LLVMVoidType(), t2c_args, 1, false);

/* Notice to the alignment */
LLVMTypeRef jit_cache_memb[2] = {LLVMInt64Type(),
LLVMPointerType(LLVMVoidType(), 0)};
t2c_jit_cache_struct_type = LLVMStructType(jit_cache_memb, 2, false);

LLVMBasicBlockRef first_block = LLVMAppendBasicBlock(start, "first_block");
LLVMBuilderRef first_builder = LLVMCreateBuilder();
LLVMPositionBuilderAtEnd(first_builder, first_block);
Expand All @@ -266,8 +280,8 @@ void t2c_compile(block_t *block, uint64_t mem_base)
struct LLVM_block_map map;
map.count = 0;
/* Translate custon IR into LLVM IR */
t2c_trace_ebb(&builder, param_types, start, &entry, mem_base,
block->ir_head, &set, &map);
t2c_trace_ebb(&builder, param_types, start, &entry, rv, block->ir_head,
&set, &map);
/* Offload LLVM IR to LLVM backend */
char *error = NULL, *triple = LLVMGetDefaultTargetTriple();
LLVMExecutionEngineRef engine;
Expand Down Expand Up @@ -298,5 +312,29 @@ void t2c_compile(block_t *block, uint64_t mem_base)

/* Return the function pointer of T2C generated machine code */
block->func = (exec_t2c_func_t) LLVMGetPointerToGlobal(engine, start);
jit_cache_update(rv->jit_cache, block->pc_start, block->func);
block->hot2 = true;
}

struct jit_cache *jit_cache_init()
{
return calloc(N_JIT_CACHE_ENTRIES, sizeof(struct jit_cache));
}

void jit_cache_exit(struct jit_cache *cache)
{
free(cache);
}

void jit_cache_update(struct jit_cache *cache, uint32_t pc, void *entry)
{
uint32_t pos = pc & (N_JIT_CACHE_ENTRIES - 1);

cache[pos].pc = pc;
cache[pos].entry = entry;
}

void jit_cache_clear(struct jit_cache *cache)
{
memset(cache, 0, N_JIT_CACHE_ENTRIES * sizeof(struct jit_cache));
}
72 changes: 63 additions & 9 deletions src/t2c_template.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,63 @@ T2C_OP(jal, {
}
})

FORCE_INLINE void t2c_jit_cache_helper(LLVMBuilderRef *builder,
LLVMValueRef start,
LLVMValueRef addr,
riscv_t *rv,
rv_insn_t *ir)
{
LLVMBasicBlockRef true_path = LLVMAppendBasicBlock(start, "");
LLVMBuilderRef true_builder = LLVMCreateBuilder();
LLVMPositionBuilderAtEnd(true_builder, true_path);

LLVMBasicBlockRef false_path = LLVMAppendBasicBlock(start, "");
LLVMBuilderRef false_builder = LLVMCreateBuilder();
LLVMPositionBuilderAtEnd(false_builder, false_path);

/* get jit-cache base address */
LLVMValueRef base = LLVMConstIntToPtr(
LLVMConstInt(LLVMInt64Type(), (long) rv->jit_cache, false),
LLVMPointerType(t2c_jit_cache_struct_type, 0));

/* get index */
LLVMValueRef hash = LLVMBuildAnd(
*builder, addr,
LLVMConstInt(LLVMInt32Type(), N_JIT_CACHE_ENTRIES - 1, false), "");

/* get jit_cache_t::pc */
LLVMValueRef cast =
LLVMBuildIntCast2(*builder, hash, LLVMInt64Type(), false, "");
LLVMValueRef element_ptr = LLVMBuildInBoundsGEP2(
*builder, t2c_jit_cache_struct_type, base, &cast, 1, "");
LLVMValueRef pc_ptr = LLVMBuildStructGEP2(
*builder, t2c_jit_cache_struct_type, element_ptr, 0, "");
LLVMValueRef pc = LLVMBuildLoad2(*builder, LLVMInt32Type(), pc_ptr, "");

/* compare with calculated destination */
LLVMValueRef cmp = LLVMBuildICmp(*builder, LLVMIntEQ, pc, addr, "");

LLVMBuildCondBr(*builder, cmp, true_path, false_path);

/* get jit_cache_t::entry */
LLVMValueRef entry_ptr = LLVMBuildStructGEP2(
true_builder, t2c_jit_cache_struct_type, element_ptr, 1, "");

/* invoke T2C JIT-ed code */
LLVMValueRef t2c_args[1] = {
LLVMConstInt(LLVMInt64Type(), (long) rv, false)};

LLVMBuildCall2(true_builder, t2c_jit_cache_func_type,
LLVMBuildLoad2(true_builder, LLVMInt64Type(), entry_ptr, ""),
t2c_args, 1, "");
LLVMBuildRetVoid(true_builder);

/* return to interpreter if cache-miss */
LLVMBuildStore(false_builder, addr,
t2c_gen_PC_addr(start, &false_builder, ir));
LLVMBuildRetVoid(false_builder);
}

T2C_OP(jalr, {
if (ir->rd)
T2C_LLVM_GEN_STORE_IMM32(*builder, ir->pc + 4,
Expand All @@ -40,8 +97,7 @@ T2C_OP(jalr, {
T2C_LLVM_GEN_LOAD_VMREG(rs1, 32, t2c_gen_rs1_addr(start, builder, ir));
val_rs1 = T2C_LLVM_GEN_ALU32_IMM(Add, val_rs1, ir->imm);
val_rs1 = T2C_LLVM_GEN_ALU32_IMM(And, val_rs1, ~1U);
LLVMBuildStore(*builder, val_rs1, t2c_gen_PC_addr(start, builder, ir));
LLVMBuildRetVoid(*builder);
t2c_jit_cache_helper(builder, start, val_rs1, rv, ir);
})

#define BRANCH_FUNC(type, cond) \
Expand Down Expand Up @@ -672,8 +728,7 @@ T2C_OP(clwsp, {

T2C_OP(cjr, {
T2C_LLVM_GEN_LOAD_VMREG(rs1, 32, t2c_gen_rs1_addr(start, builder, ir));
LLVMBuildStore(*builder, val_rs1, t2c_gen_PC_addr(start, builder, ir));
LLVMBuildRetVoid(*builder);
t2c_jit_cache_helper(builder, start, val_rs1, rv, ir);
})

T2C_OP(cmv, {
Expand All @@ -692,8 +747,7 @@ T2C_OP(cjalr, {
T2C_LLVM_GEN_STORE_IMM32(*builder, ir->pc + 2,
t2c_gen_ra_addr(start, builder, ir));
T2C_LLVM_GEN_LOAD_VMREG(rs1, 32, t2c_gen_rs1_addr(start, builder, ir));
LLVMBuildStore(*builder, val_rs1, t2c_gen_PC_addr(start, builder, ir));
LLVMBuildRetVoid(*builder);
t2c_jit_cache_helper(builder, start, val_rs1, rv, ir);
})

T2C_OP(cadd, {
Expand Down Expand Up @@ -785,15 +839,15 @@ T2C_OP(fuse5, {
switch (fuse[i].opcode) {
case rv_insn_slli:
t2c_slli(builder, param_types, start, entry, taken_builder,
untaken_builder, mem_base, (rv_insn_t *) (&fuse[i]));
untaken_builder, rv, mem_base, (rv_insn_t *) (&fuse[i]));
break;
case rv_insn_srli:
t2c_srli(builder, param_types, start, entry, taken_builder,
untaken_builder, mem_base, (rv_insn_t *) (&fuse[i]));
untaken_builder, rv, mem_base, (rv_insn_t *) (&fuse[i]));
break;
case rv_insn_srai:
t2c_srai(builder, param_types, start, entry, taken_builder,
untaken_builder, mem_base, (rv_insn_t *) (&fuse[i]));
untaken_builder, rv, mem_base, (rv_insn_t *) (&fuse[i]));
break;
default:
__UNREACHABLE;
Expand Down
43 changes: 43 additions & 0 deletions tests/fibonacci.s
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
fib:
li a5, 1
bleu a0, a5, .L3
addi sp, sp, -16
sw ra, 12(sp)
sw s0, 8(sp)
sw s1, 4(sp)
mv s0, a0
addi a0, a0, -1
la t0, fib
jalr ra, 0(t0)
mv s1, a0
addi a0, s0, -2
la t0, fib
jalr ra, 0(t0)
add a0, s1, a0
lw ra, 12(sp)
lw s0, 8(sp)
lw s1, 4(sp)
addi sp, sp, 16
jr ra
.L3:
li a0, 1
ret
.LC0:
.string "%d\n"
.text
.align 1
.globl main
.type main, @function
main:
addi sp, sp, -16
sw ra, 12(sp)
li a0, 42
call fib
mv a1, a0
lui a0, %hi(.LC0)
addi a0, a0, %lo(.LC0)
call printf
li a0, 0
lw ra, 12(sp)
addi sp, sp, 16
jr ra