Skip to content

Commit f71bee4

Browse files
wanda-phiwhitequark
authored andcommitted
sim: evaluate simulator commands in-place instead of compiling them.
1 parent 967dabc commit f71bee4

File tree

4 files changed

+235
-24
lines changed

4 files changed

+235
-24
lines changed

amaranth/hdl/_ast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
__all__ = [
1919
"SyntaxError", "SyntaxWarning",
2020
"Shape", "signed", "unsigned", "ShapeCastable", "ShapeLike",
21-
"Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Concat", "SwitchValue",
21+
"Value", "Const", "C", "AnyValue", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Concat", "SwitchValue",
2222
"Array", "ArrayProxy",
2323
"Signal", "ClockSignal", "ResetSignal",
2424
"ValueCastable", "ValueLike",

amaranth/sim/_pycoro.py

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from ..hdl._mem import MemorySimRead, MemorySimWrite
66
from .core import Tick, Settle, Delay, Passive, Active
77
from ._base import BaseProcess, BaseMemoryState
8-
from ._pyrtl import _ValueCompiler, _RHSValueCompiler, _StatementCompiler
8+
from ._pyeval import eval_value, eval_assign
99

1010

1111
__all__ = ["PyCoroProcess"]
@@ -28,11 +28,6 @@ def reset(self):
2828
self.passive = False
2929

3030
self.coroutine = self.constructor()
31-
self.exec_locals = {
32-
"slots": self.state.slots,
33-
"result": None,
34-
**_ValueCompiler.helpers
35-
}
3631
self.waits_on = SignalSet()
3732

3833
def src_loc(self):
@@ -87,14 +82,11 @@ def run(self):
8782
if isinstance(command, ValueCastable):
8883
command = Value.cast(command)
8984
if isinstance(command, Value):
90-
exec(_RHSValueCompiler.compile(self.state, command, mode="curr"),
91-
self.exec_locals)
92-
response = Const(self.exec_locals["result"], command.shape()).value
93-
94-
elif isinstance(command, Statement):
95-
exec(_StatementCompiler.compile(self.state, command),
96-
self.exec_locals)
97-
if isinstance(command, Assign) and self.testbench:
85+
response = eval_value(self.state, command)
86+
87+
elif isinstance(command, Assign):
88+
eval_assign(self.state, command.lhs, eval_value(self.state, command.rhs))
89+
if self.testbench:
9890
return True # assignment; run a delta cycle
9991

10092
elif type(command) is Tick:
@@ -132,21 +124,15 @@ def run(self):
132124
self.passive = False
133125

134126
elif type(command) is MemorySimRead:
135-
exec(_RHSValueCompiler.compile(self.state, command._addr, mode="curr"),
136-
self.exec_locals)
137-
addr = Const(self.exec_locals["result"], command._addr.shape()).value
127+
addr = eval_value(self.state, command._addr)
138128
index = self.state.get_memory(command._memory)
139129
state = self.state.slots[index]
140130
assert isinstance(state, BaseMemoryState)
141131
response = state.read(addr)
142132

143133
elif type(command) is MemorySimWrite:
144-
exec(_RHSValueCompiler.compile(self.state, command._addr, mode="curr"),
145-
self.exec_locals)
146-
addr = Const(self.exec_locals["result"], command._addr.shape()).value
147-
exec(_RHSValueCompiler.compile(self.state, command._data, mode="curr"),
148-
self.exec_locals)
149-
data = Const(self.exec_locals["result"], command._data.shape()).value
134+
addr = eval_value(self.state, command._addr)
135+
data = eval_value(self.state, command._data)
150136
index = self.state.get_memory(command._memory)
151137
state = self.state.slots[index]
152138
assert isinstance(state, BaseMemoryState)

amaranth/sim/_pyeval.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
from amaranth.hdl._ast import *
2+
3+
4+
def _eval_matches(test, patterns):
5+
if patterns is None:
6+
return True
7+
for pattern in patterns:
8+
if isinstance(pattern, str):
9+
mask = int("".join("0" if b == "-" else "1" for b in pattern), 2)
10+
value = int("".join("0" if b == "-" else b for b in pattern), 2)
11+
if value == (mask & test):
12+
return True
13+
else:
14+
if pattern == test:
15+
return True
16+
return False
17+
18+
19+
def eval_value(sim, value):
20+
if isinstance(value, Const):
21+
return value.value
22+
elif isinstance(value, Operator):
23+
if len(value.operands) == 1:
24+
op_a = eval_value(sim, value.operands[0])
25+
if value.operator in ("u", "s"):
26+
width = value.shape().width
27+
res = op_a
28+
res &= (1 << width) - 1
29+
if value.operator == "s" and res & (1 << (width - 1)):
30+
res |= -1 << (width - 1)
31+
return res
32+
elif value.operator == "-":
33+
return -op_a
34+
elif value.operator == "~":
35+
shape = value.shape()
36+
if shape.signed:
37+
return ~op_a
38+
else:
39+
return ~op_a & ((1 << shape.width) - 1)
40+
elif value.operator in ("b", "r|"):
41+
return int(op_a != 0)
42+
elif value.operator == "r&":
43+
width = value.operands[0].shape().width
44+
mask = (1 << width) - 1
45+
return int((op_a & mask) == mask)
46+
elif value.operator == "r^":
47+
width = value.operands[0].shape().width
48+
mask = (1 << width) - 1
49+
# Believe it or not, this is the fastest way to compute a sideways XOR in Python.
50+
return format(op_a & mask, 'b').count('1') % 2
51+
elif len(value.operands) == 2:
52+
op_a = eval_value(sim, value.operands[0])
53+
op_b = eval_value(sim, value.operands[1])
54+
if value.operator == "|":
55+
return op_a | op_b
56+
elif value.operator == "&":
57+
return op_a & op_b
58+
elif value.operator == "^":
59+
return op_a ^ op_b
60+
elif value.operator == "+":
61+
return op_a + op_b
62+
elif value.operator == "-":
63+
return op_a - op_b
64+
elif value.operator == "*":
65+
return op_a * op_b
66+
elif value.operator == "//":
67+
if op_b == 0:
68+
return 0
69+
return op_a // op_b
70+
elif value.operator == "%":
71+
if op_b == 0:
72+
return 0
73+
return op_a % op_b
74+
elif value.operator == "<<":
75+
return op_a << op_b
76+
elif value.operator == ">>":
77+
return op_a >> op_b
78+
elif value.operator == "==":
79+
return int(op_a == op_b)
80+
elif value.operator == "!=":
81+
return int(op_a != op_b)
82+
elif value.operator == "<":
83+
return int(op_a < op_b)
84+
elif value.operator == "<=":
85+
return int(op_a <= op_b)
86+
elif value.operator == ">":
87+
return int(op_a > op_b)
88+
elif value.operator == ">=":
89+
return int(op_a >= op_b)
90+
assert False # :nocov:
91+
elif isinstance(value, Slice):
92+
res = eval_value(sim, value.value)
93+
res >>= value.start
94+
width = value.stop - value.start
95+
return res & ((1 << width) - 1)
96+
elif isinstance(value, Part):
97+
res = eval_value(sim, value.value)
98+
offset = eval_value(sim, value.offset)
99+
offset *= value.stride
100+
res >>= offset
101+
return res & ((1 << value.width) - 1)
102+
elif isinstance(value, Concat):
103+
res = 0
104+
pos = 0
105+
for part in value.parts:
106+
width = len(part)
107+
part = eval_value(sim, part)
108+
part &= (1 << width) - 1
109+
res |= part << pos
110+
pos += width
111+
return res
112+
elif isinstance(value, SwitchValue):
113+
test = eval_value(sim, value.test)
114+
for patterns, val in value.cases:
115+
if _eval_matches(test, patterns):
116+
return eval_value(sim, val)
117+
return 0
118+
elif isinstance(value, Signal):
119+
slot = sim.get_signal(value)
120+
return sim.slots[slot].curr
121+
elif isinstance(value, (ResetSignal, ClockSignal, AnyValue, Initial)):
122+
raise ValueError(f"Value {value!r} cannot be used in simulation")
123+
else:
124+
assert False # :nocov:
125+
126+
127+
def _eval_assign_inner(sim, lhs, lhs_start, rhs, rhs_len):
128+
if isinstance(lhs, Operator) and lhs.operator in ("u", "s"):
129+
_eval_assign_inner(sim, lhs.operands[0], lhs_start, rhs, rhs_len)
130+
elif isinstance(lhs, Signal):
131+
lhs_stop = lhs_start + rhs_len
132+
if lhs_stop > len(lhs):
133+
lhs_stop = len(lhs)
134+
if lhs_start >= len(lhs):
135+
return
136+
slot = sim.get_signal(lhs)
137+
value = sim.slots[slot].next
138+
mask = (1 << lhs_stop) - (1 << lhs_start)
139+
value &= ~mask
140+
value |= (rhs << lhs_start) & mask
141+
value &= (1 << len(lhs)) - 1
142+
if lhs._signed and (value & (1 << (len(lhs) - 1))):
143+
value |= -1 << (len(lhs) - 1)
144+
sim.slots[slot].set(value)
145+
elif isinstance(lhs, Slice):
146+
_eval_assign_inner(sim, lhs.value, lhs_start + lhs.start, rhs, rhs_len)
147+
elif isinstance(lhs, Concat):
148+
part_stop = 0
149+
for part in lhs.parts:
150+
part_start = part_stop
151+
part_len = len(part)
152+
part_stop = part_start + part_len
153+
if lhs_start >= part_stop:
154+
continue
155+
if lhs_start + rhs_len <= part_start:
156+
continue
157+
if lhs_start < part_start:
158+
part_lhs_start = 0
159+
part_rhs_start = part_start - lhs_start
160+
else:
161+
part_lhs_start = lhs_start - part_start
162+
part_rhs_start = 0
163+
if lhs_start + rhs_len >= part_stop:
164+
part_rhs_len = part_stop - lhs_start - part_rhs_start
165+
else:
166+
part_rhs_len = rhs_len - part_rhs_start
167+
part_rhs = rhs >> part_rhs_start
168+
part_rhs &= (1 << part_rhs_len) - 1
169+
_eval_assign_inner(sim, part, part_lhs_start, part_rhs, part_rhs_len)
170+
elif isinstance(lhs, Part):
171+
offset = eval_value(sim, lhs.offset)
172+
offset *= lhs.stride
173+
_eval_assign_inner(sim, lhs.value, lhs_start + offset, rhs, rhs_len)
174+
elif isinstance(lhs, SwitchValue):
175+
test = eval_value(sim, lhs.test)
176+
for patterns, val in lhs.cases:
177+
if _eval_matches(test, patterns):
178+
_eval_assign_inner(sim, val, lhs_start, rhs, rhs_len)
179+
return
180+
else:
181+
raise ValueError(f"Value {lhs!r} cannot be assigned")
182+
183+
def eval_assign(sim, lhs, value):
184+
_eval_assign_inner(sim, lhs, 0, value, len(lhs))

tests/test_sim.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,30 @@ def process():
4343
with sim.write_vcd("test.vcd", "test.gtkw", traces=[*isigs, osig]):
4444
sim.run()
4545

46+
frag = Fragment()
47+
sim = Simulator(frag)
48+
def process():
49+
for isig, input in zip(isigs, inputs):
50+
yield isig.eq(input)
51+
yield Delay(0)
52+
if isinstance(stmt, Assign):
53+
yield stmt
54+
else:
55+
yield from stmt
56+
yield Delay(0)
57+
self.assertEqual((yield osig), output.value)
58+
sim.add_testbench(process)
59+
with sim.write_vcd("test.vcd", "test.gtkw", traces=[*isigs, osig]):
60+
sim.run()
61+
62+
4663
def test_invert(self):
4764
stmt = lambda y, a: y.eq(~a)
4865
self.assertStatement(stmt, [C(0b0000, 4)], C(0b1111, 4))
4966
self.assertStatement(stmt, [C(0b1010, 4)], C(0b0101, 4))
5067
self.assertStatement(stmt, [C(0, 4)], C(-1, 4))
68+
self.assertStatement(stmt, [C(0b0000, signed(4))], C(-1, signed(4)))
69+
self.assertStatement(stmt, [C(0b1010, signed(4))], C(0b0101, signed(4)))
5170

5271
def test_neg(self):
5372
stmt = lambda y, a: y.eq(-a)
@@ -126,6 +145,7 @@ def test_mul(self):
126145

127146
def test_floordiv(self):
128147
stmt = lambda y, a, b: y.eq(a // b)
148+
self.assertStatement(stmt, [C(2, 4), C(0, 4)], C(0, 8))
129149
self.assertStatement(stmt, [C(2, 4), C(1, 4)], C(2, 8))
130150
self.assertStatement(stmt, [C(2, 4), C(2, 4)], C(1, 8))
131151
self.assertStatement(stmt, [C(7, 4), C(2, 4)], C(3, 8))
@@ -285,6 +305,17 @@ def test_cat_lhs(self):
285305
stmt = lambda y, a: [Cat(l, m, n).eq(a), y.eq(Cat(n, m, l))]
286306
self.assertStatement(stmt, [C(0b100101110, 9)], C(0b110101100, 9))
287307

308+
def test_cat_slice_lhs(self):
309+
l = Signal(3)
310+
m = Signal(3)
311+
n = Signal(3)
312+
o = Signal(3)
313+
p = Signal(3)
314+
stmt = lambda y, a: [Cat(l, m, n, o, p).eq(-1), Cat(l, m, n, o, p)[4:11].eq(a), y.eq(Cat(p, o, n, m, l))]
315+
self.assertStatement(stmt, [C(0b0000000, 7)], C(0b111001000100111, 15))
316+
self.assertStatement(stmt, [C(0b1001011, 7)], C(0b111111010110111, 15))
317+
self.assertStatement(stmt, [C(0b1111111, 7)], C(0b111111111111111, 15))
318+
288319
def test_nested_cat_lhs(self):
289320
l = Signal(3)
290321
m = Signal(3)
@@ -327,6 +358,16 @@ def test_array_lhs(self):
327358
self.assertStatement(stmt, [C(1), C(0b010)], C(0b111010001))
328359
self.assertStatement(stmt, [C(2), C(0b100)], C(0b100100001))
329360

361+
def test_array_lhs_heterogenous(self):
362+
l = Signal(1, init=1)
363+
m = Signal(3, init=4)
364+
n = Signal(5, init=7)
365+
array = Array([l, m, n])
366+
stmt = lambda y, a, b: [array[a].eq(b), y.eq(Cat(*array))]
367+
self.assertStatement(stmt, [C(0), C(0b000)], C(0b001111000, 9))
368+
self.assertStatement(stmt, [C(1), C(0b010)], C(0b001110101, 9))
369+
self.assertStatement(stmt, [C(2), C(0b100)], C(0b001001001, 9))
370+
330371
def test_array_lhs_oob(self):
331372
l = Signal(3)
332373
m = Signal(3)

0 commit comments

Comments
 (0)