Skip to content

Commit 34c3db0

Browse files
authored
Merge pull request #471 from vacantron/t2c/jalr
Improve `JALR` execution with JIT-cache
2 parents 9759ad2 + f5d04fb commit 34c3db0

File tree

9 files changed

+188
-30
lines changed

9 files changed

+188
-30
lines changed

.ci/riscv-tests.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
set -e -u -o pipefail
44

55
# Install RISCOF
6-
python3 -m pip install git+https://github.com/riscv/riscof
6+
pip3 install git+https://github.com/riscv/riscof.git@d38859f85fe407bcacddd2efcd355ada4683aee4
77

88
set -x
99

build/fibonacci.elf

74.6 KB
Binary file not shown.

src/jit.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1864,6 +1864,7 @@ static void code_cache_flush(struct jit_state *state, riscv_t *rv)
18641864
state->offset = state->org_size;
18651865
state->n_blocks = 0;
18661866
set_reset(&state->set);
1867+
jit_cache_clear(rv->jit_cache);
18671868
clear_cache_hot(rv->block_cache, (clear_func_t) clear_hot);
18681869
return;
18691870
}

src/jit.h

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,26 @@ void jit_translate(riscv_t *rv, block_t *block);
5151
typedef void (*exec_block_func_t)(riscv_t *rv, uintptr_t);
5252

5353
#if RV32_HAS(T2C)
54-
void t2c_compile(block_t *block, uint64_t mem_base);
54+
void t2c_compile(riscv_t *, block_t *);
5555
typedef void (*exec_t2c_func_t)(riscv_t *);
56+
57+
/* The jit-cache maps a guest program counter to the entry point of the
 * machine code that T2C generated for it. It behaves like a direct-mapped
 * hardware cache: a new record overwrites whatever occupied the same slot.
 */

/* The number of slots must be a power of 2, so that a slot index can be
 * derived by masking the program counter with (N_JIT_CACHE_ENTRIES - 1).
 */
#define N_JIT_CACHE_ENTRIES (1 << 12)

/* One slot of the jit-cache table. */
struct jit_cache {
    uint64_t pc; /* program counter, easy to build LLVM IR with 64-bit width */
    void *entry; /* entry of JIT-ed code */
};

/* Allocate a zero-filled table of N_JIT_CACHE_ENTRIES slots.
 * The result comes straight from calloc(), so it is NULL on failure.
 */
struct jit_cache *jit_cache_init(void);

/* Free a table obtained from jit_cache_init(). */
void jit_cache_exit(struct jit_cache *cache);

/* Record that the JIT-ed code for 'pc' starts at 'entry', overwriting the
 * slot selected by the low bits of 'pc'.
 */
void jit_cache_update(struct jit_cache *cache, uint32_t pc, void *entry);

/* Zero every slot of the table. */
void jit_cache_clear(struct jit_cache *cache);
5676
#endif

src/riscv.c

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,7 @@ static void *t2c_runloop(void *arg)
199199
pthread_mutex_lock(&rv->wait_queue_lock);
200200
list_del_init(&entry->list);
201201
pthread_mutex_unlock(&rv->wait_queue_lock);
202-
t2c_compile(entry->block,
203-
(uint64_t) ((memory_t *) PRIV(rv)->mem)->mem_base);
202+
t2c_compile(rv, entry->block);
204203
free(entry);
205204
}
206205
}
@@ -291,6 +290,7 @@ riscv_t *rv_create(riscv_user_t rv_attr)
291290
mpool_create(sizeof(chain_entry_t) << BLOCK_IR_MAP_CAPACITY_BITS,
292291
sizeof(chain_entry_t));
293292
rv->jit_state = jit_state_init(CODE_CACHE_SIZE);
293+
rv->jit_cache = jit_cache_init();
294294
rv->block_cache = cache_create(BLOCK_MAP_CAPACITY_BITS);
295295
assert(rv->block_cache);
296296
#if RV32_HAS(T2C)
@@ -392,6 +392,7 @@ void rv_delete(riscv_t *rv)
392392
#endif
393393
mpool_destroy(rv->chain_entry_mp);
394394
jit_state_exit(rv->jit_state);
395+
jit_cache_exit(rv->jit_cache);
395396
cache_free(rv->block_cache);
396397
#endif
397398
free(rv);

src/riscv_private.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ struct riscv_internal {
155155
struct mpool *block_mp, *block_ir_mp;
156156

157157
void *jit_state;
158+
void *jit_cache;
158159
#if RV32_HAS(GDBSTUB)
159160
/* gdbstub instance */
160161
gdbstub_t gdbstub;

src/t2c.c

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,15 @@ FORCE_INLINE LLVMBasicBlockRef t2c_block_map_search(struct LLVM_block_map *map,
4949
return NULL;
5050
}
5151

52-
#define T2C_OP(inst, code) \
53-
static void t2c_##inst( \
54-
LLVMBuilderRef *builder UNUSED, LLVMTypeRef *param_types UNUSED, \
55-
LLVMValueRef start UNUSED, LLVMBasicBlockRef *entry UNUSED, \
56-
LLVMBuilderRef *taken_builder UNUSED, \
57-
LLVMBuilderRef *untaken_builder UNUSED, uint64_t mem_base UNUSED, \
58-
rv_insn_t *ir UNUSED) \
59-
{ \
60-
code; \
52+
/* Define the static LLVM IR generator t2c_<inst> whose body is 'code'.
 * Every generator takes the same parameter list, which matches
 * t2c_codegen_block_func_t so the generators can be invoked through the
 * dispatch table. Each parameter is tagged UNUSED because a given
 * instruction typically needs only a few of them.
 */
#define T2C_OP(inst, code)                                         \
    static void t2c_##inst(LLVMBuilderRef *builder UNUSED,         \
                           LLVMTypeRef *param_types UNUSED,        \
                           LLVMValueRef start UNUSED,              \
                           LLVMBasicBlockRef *entry UNUSED,        \
                           LLVMBuilderRef *taken_builder UNUSED,   \
                           LLVMBuilderRef *untaken_builder UNUSED, \
                           riscv_t *rv UNUSED,                     \
                           uint64_t mem_base UNUSED,               \
                           rv_insn_t *ir UNUSED)                   \
    {                                                              \
        code;                                                      \
    }
6262

6363
#define T2C_LLVM_GEN_ADDR(reg, rv_member, ir_member) \
@@ -135,6 +135,9 @@ FORCE_INLINE void t2c_gen_call_io_func(LLVMValueRef start,
135135
&io_param, 1, "");
136136
}
137137

138+
static LLVMTypeRef t2c_jit_cache_func_type;
139+
static LLVMTypeRef t2c_jit_cache_struct_type;
140+
138141
#include "t2c_template.c"
139142
#undef T2C_OP
140143

@@ -174,14 +177,15 @@ typedef void (*t2c_codegen_block_func_t)(LLVMBuilderRef *builder UNUSED,
174177
LLVMBasicBlockRef *entry UNUSED,
175178
LLVMBuilderRef *taken_builder UNUSED,
176179
LLVMBuilderRef *untaken_builder UNUSED,
180+
riscv_t *rv UNUSED,
177181
uint64_t mem_base UNUSED,
178182
rv_insn_t *ir UNUSED);
179183

180184
static void t2c_trace_ebb(LLVMBuilderRef *builder,
181185
LLVMTypeRef *param_types UNUSED,
182186
LLVMValueRef start,
183187
LLVMBasicBlockRef *entry,
184-
uint64_t mem_base,
188+
riscv_t *rv,
185189
rv_insn_t *ir,
186190
set_t *set,
187191
struct LLVM_block_map *map)
@@ -194,7 +198,8 @@ static void t2c_trace_ebb(LLVMBuilderRef *builder,
194198

195199
while (1) {
196200
((t2c_codegen_block_func_t) dispatch_table[ir->opcode])(
197-
builder, param_types, start, entry, &tk, &utk, mem_base, ir);
201+
builder, param_types, start, entry, &tk, &utk, rv,
202+
(uint64_t) ((memory_t *) PRIV(rv)->mem)->mem_base, ir);
198203
if (!ir->next)
199204
break;
200205
ir = ir->next;
@@ -214,8 +219,7 @@ static void t2c_trace_ebb(LLVMBuilderRef *builder,
214219
LLVMPositionBuilderAtEnd(untaken_builder, untaken_entry);
215220
LLVMBuildBr(utk, untaken_entry);
216221
t2c_trace_ebb(&untaken_builder, param_types, start,
217-
&untaken_entry, mem_base, ir->branch_untaken, set,
218-
map);
222+
&untaken_entry, rv, ir->branch_untaken, set, map);
219223
}
220224
}
221225
if (ir->branch_taken) {
@@ -230,13 +234,13 @@ static void t2c_trace_ebb(LLVMBuilderRef *builder,
230234
LLVMPositionBuilderAtEnd(taken_builder, taken_entry);
231235
LLVMBuildBr(tk, taken_entry);
232236
t2c_trace_ebb(&taken_builder, param_types, start, &taken_entry,
233-
mem_base, ir->branch_taken, set, map);
237+
rv, ir->branch_taken, set, map);
234238
}
235239
}
236240
}
237241
}
238242

239-
void t2c_compile(block_t *block, uint64_t mem_base)
243+
void t2c_compile(riscv_t *rv, block_t *block)
240244
{
241245
LLVMModuleRef module = LLVMModuleCreateWithName("my_module");
242246
LLVMTypeRef io_members[] = {
@@ -254,6 +258,16 @@ void t2c_compile(block_t *block, uint64_t mem_base)
254258
LLVMTypeRef param_types[] = {LLVMPointerType(struct_rv, 0)};
255259
LLVMValueRef start = LLVMAddFunction(
256260
module, "start", LLVMFunctionType(LLVMVoidType(), param_types, 1, 0));
261+
262+
LLVMTypeRef t2c_args[1] = {LLVMInt64Type()};
263+
t2c_jit_cache_func_type =
264+
LLVMFunctionType(LLVMVoidType(), t2c_args, 1, false);
265+
266+
/* Must mirror the member layout (and alignment) of struct jit_cache */
267+
LLVMTypeRef jit_cache_memb[2] = {LLVMInt64Type(),
268+
LLVMPointerType(LLVMVoidType(), 0)};
269+
t2c_jit_cache_struct_type = LLVMStructType(jit_cache_memb, 2, false);
270+
257271
LLVMBasicBlockRef first_block = LLVMAppendBasicBlock(start, "first_block");
258272
LLVMBuilderRef first_builder = LLVMCreateBuilder();
259273
LLVMPositionBuilderAtEnd(first_builder, first_block);
@@ -266,8 +280,8 @@ void t2c_compile(block_t *block, uint64_t mem_base)
266280
struct LLVM_block_map map;
267281
map.count = 0;
268282
/* Translate custom IR into LLVM IR */
269-
t2c_trace_ebb(&builder, param_types, start, &entry, mem_base,
270-
block->ir_head, &set, &map);
283+
t2c_trace_ebb(&builder, param_types, start, &entry, rv, block->ir_head,
284+
&set, &map);
271285
/* Offload LLVM IR to LLVM backend */
272286
char *error = NULL, *triple = LLVMGetDefaultTargetTriple();
273287
LLVMExecutionEngineRef engine;
@@ -298,5 +312,29 @@ void t2c_compile(block_t *block, uint64_t mem_base)
298312

299313
/* Return the function pointer of T2C generated machine code */
300314
block->func = (exec_t2c_func_t) LLVMGetPointerToGlobal(engine, start);
315+
jit_cache_update(rv->jit_cache, block->pc_start, block->func);
301316
block->hot2 = true;
302317
}
318+
319+
struct jit_cache *jit_cache_init()
320+
{
321+
return calloc(N_JIT_CACHE_ENTRIES, sizeof(struct jit_cache));
322+
}
323+
324+
/* Release a jit-cache table. Only the table itself is freed; the 'entry'
 * pointers stored in its slots are not touched.
 */
void jit_cache_exit(struct jit_cache *cache)
{
    free(cache);
}
328+
329+
void jit_cache_update(struct jit_cache *cache, uint32_t pc, void *entry)
330+
{
331+
uint32_t pos = pc & (N_JIT_CACHE_ENTRIES - 1);
332+
333+
cache[pos].pc = pc;
334+
cache[pos].entry = entry;
335+
}
336+
337+
void jit_cache_clear(struct jit_cache *cache)
338+
{
339+
memset(cache, 0, N_JIT_CACHE_ENTRIES * sizeof(struct jit_cache));
340+
}

src/t2c_template.c

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,63 @@ T2C_OP(jal, {
3232
}
3333
})
3434

35+
FORCE_INLINE void t2c_jit_cache_helper(LLVMBuilderRef *builder,
36+
LLVMValueRef start,
37+
LLVMValueRef addr,
38+
riscv_t *rv,
39+
rv_insn_t *ir)
40+
{
41+
LLVMBasicBlockRef true_path = LLVMAppendBasicBlock(start, "");
42+
LLVMBuilderRef true_builder = LLVMCreateBuilder();
43+
LLVMPositionBuilderAtEnd(true_builder, true_path);
44+
45+
LLVMBasicBlockRef false_path = LLVMAppendBasicBlock(start, "");
46+
LLVMBuilderRef false_builder = LLVMCreateBuilder();
47+
LLVMPositionBuilderAtEnd(false_builder, false_path);
48+
49+
/* get jit-cache base address */
50+
LLVMValueRef base = LLVMConstIntToPtr(
51+
LLVMConstInt(LLVMInt64Type(), (long) rv->jit_cache, false),
52+
LLVMPointerType(t2c_jit_cache_struct_type, 0));
53+
54+
/* get index */
55+
LLVMValueRef hash = LLVMBuildAnd(
56+
*builder, addr,
57+
LLVMConstInt(LLVMInt32Type(), N_JIT_CACHE_ENTRIES - 1, false), "");
58+
59+
/* get jit_cache_t::pc */
60+
LLVMValueRef cast =
61+
LLVMBuildIntCast2(*builder, hash, LLVMInt64Type(), false, "");
62+
LLVMValueRef element_ptr = LLVMBuildInBoundsGEP2(
63+
*builder, t2c_jit_cache_struct_type, base, &cast, 1, "");
64+
LLVMValueRef pc_ptr = LLVMBuildStructGEP2(
65+
*builder, t2c_jit_cache_struct_type, element_ptr, 0, "");
66+
LLVMValueRef pc = LLVMBuildLoad2(*builder, LLVMInt32Type(), pc_ptr, "");
67+
68+
/* compare with calculated destination */
69+
LLVMValueRef cmp = LLVMBuildICmp(*builder, LLVMIntEQ, pc, addr, "");
70+
71+
LLVMBuildCondBr(*builder, cmp, true_path, false_path);
72+
73+
/* get jit_cache_t::entry */
74+
LLVMValueRef entry_ptr = LLVMBuildStructGEP2(
75+
true_builder, t2c_jit_cache_struct_type, element_ptr, 1, "");
76+
77+
/* invoke T2C JIT-ed code */
78+
LLVMValueRef t2c_args[1] = {
79+
LLVMConstInt(LLVMInt64Type(), (long) rv, false)};
80+
81+
LLVMBuildCall2(true_builder, t2c_jit_cache_func_type,
82+
LLVMBuildLoad2(true_builder, LLVMInt64Type(), entry_ptr, ""),
83+
t2c_args, 1, "");
84+
LLVMBuildRetVoid(true_builder);
85+
86+
/* return to interpreter if cache-miss */
87+
LLVMBuildStore(false_builder, addr,
88+
t2c_gen_PC_addr(start, &false_builder, ir));
89+
LLVMBuildRetVoid(false_builder);
90+
}
91+
3592
T2C_OP(jalr, {
3693
if (ir->rd)
3794
T2C_LLVM_GEN_STORE_IMM32(*builder, ir->pc + 4,
@@ -40,8 +97,7 @@ T2C_OP(jalr, {
4097
T2C_LLVM_GEN_LOAD_VMREG(rs1, 32, t2c_gen_rs1_addr(start, builder, ir));
4198
val_rs1 = T2C_LLVM_GEN_ALU32_IMM(Add, val_rs1, ir->imm);
4299
val_rs1 = T2C_LLVM_GEN_ALU32_IMM(And, val_rs1, ~1U);
43-
LLVMBuildStore(*builder, val_rs1, t2c_gen_PC_addr(start, builder, ir));
44-
LLVMBuildRetVoid(*builder);
100+
t2c_jit_cache_helper(builder, start, val_rs1, rv, ir);
45101
})
46102

47103
#define BRANCH_FUNC(type, cond) \
@@ -672,8 +728,7 @@ T2C_OP(clwsp, {
672728

673729
T2C_OP(cjr, {
674730
T2C_LLVM_GEN_LOAD_VMREG(rs1, 32, t2c_gen_rs1_addr(start, builder, ir));
675-
LLVMBuildStore(*builder, val_rs1, t2c_gen_PC_addr(start, builder, ir));
676-
LLVMBuildRetVoid(*builder);
731+
t2c_jit_cache_helper(builder, start, val_rs1, rv, ir);
677732
})
678733

679734
T2C_OP(cmv, {
@@ -692,8 +747,7 @@ T2C_OP(cjalr, {
692747
T2C_LLVM_GEN_STORE_IMM32(*builder, ir->pc + 2,
693748
t2c_gen_ra_addr(start, builder, ir));
694749
T2C_LLVM_GEN_LOAD_VMREG(rs1, 32, t2c_gen_rs1_addr(start, builder, ir));
695-
LLVMBuildStore(*builder, val_rs1, t2c_gen_PC_addr(start, builder, ir));
696-
LLVMBuildRetVoid(*builder);
750+
t2c_jit_cache_helper(builder, start, val_rs1, rv, ir);
697751
})
698752

699753
T2C_OP(cadd, {
@@ -785,15 +839,15 @@ T2C_OP(fuse5, {
785839
switch (fuse[i].opcode) {
786840
case rv_insn_slli:
787841
t2c_slli(builder, param_types, start, entry, taken_builder,
788-
untaken_builder, mem_base, (rv_insn_t *) (&fuse[i]));
842+
untaken_builder, rv, mem_base, (rv_insn_t *) (&fuse[i]));
789843
break;
790844
case rv_insn_srli:
791845
t2c_srli(builder, param_types, start, entry, taken_builder,
792-
untaken_builder, mem_base, (rv_insn_t *) (&fuse[i]));
846+
untaken_builder, rv, mem_base, (rv_insn_t *) (&fuse[i]));
793847
break;
794848
case rv_insn_srai:
795849
t2c_srai(builder, param_types, start, entry, taken_builder,
796-
untaken_builder, mem_base, (rv_insn_t *) (&fuse[i]));
850+
untaken_builder, rv, mem_base, (rv_insn_t *) (&fuse[i]));
797851
break;
798852
default:
799853
__UNREACHABLE;

tests/fibonacci.s

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# fib(a0 = n): naive doubly-recursive Fibonacci with fib(0) = fib(1) = 1.
# Both recursive calls go through `jalr` on a register loaded with `la`,
# so every call and return is an indirect jump.
fib:
    li a5, 1
    bleu a0, a5, .L3            # n <= 1: take the base case
    addi sp, sp, -16            # frame: ra, s0, s1
    sw ra, 12(sp)
    sw s0, 8(sp)
    sw s1, 4(sp)
    mv s0, a0                   # s0 = n
    addi a0, a0, -1
    la t0, fib
    jalr ra, 0(t0)              # a0 = fib(n - 1)
    mv s1, a0                   # s1 = fib(n - 1)
    addi a0, s0, -2
    la t0, fib
    jalr ra, 0(t0)              # a0 = fib(n - 2)
    add a0, s1, a0              # result = fib(n - 1) + fib(n - 2)
    lw ra, 12(sp)
    lw s0, 8(sp)
    lw s1, 4(sp)
    addi sp, sp, 16
    jr ra
.L3:
    li a0, 1                    # base case result
    ret
.LC0:
    .string "%d\n"
    .text
    .align 1
    .globl main
    .type main, @function
# main: print fib(42) with printf and return 0.
main:
    addi sp, sp, -16
    sw ra, 12(sp)
    li a0, 42
    call fib
    mv a1, a0                   # second printf argument = fib(42)
    lui a0, %hi(.LC0)
    addi a0, a0, %lo(.LC0)      # first printf argument = format string
    call printf
    li a0, 0                    # exit status
    lw ra, 12(sp)
    addi sp, sp, 16
    jr ra

0 commit comments

Comments
 (0)