Skip to content

Commit b6c5294

Browse files
wanda-phiwhitequark
authored andcommitted
hdl.MemoryInstance: refactor and add first-class simulation support.
1 parent f4daf74 commit b6c5294

File tree

9 files changed

+527
-203
lines changed

9 files changed

+527
-203
lines changed

amaranth/back/rtlil.py

Lines changed: 41 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -787,69 +787,71 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
787787
return f"\\{fragment.type}", port_map, params
788788

789789
if isinstance(fragment, _mem.MemoryInstance):
790-
memory = fragment.memory
791-
init = "".join(format(_ast.Const(elem, _ast.unsigned(memory.width)).value, f"0{memory.width}b") for elem in reversed(memory.init))
792-
init = _ast.Const(int(init or "0", 2), memory.depth * memory.width)
790+
init = "".join(format(_ast.Const(elem, _ast.unsigned(fragment._width)).value, f"0{fragment._width}b") for elem in reversed(fragment._init))
791+
init = _ast.Const(int(init or "0", 2), fragment._depth * fragment._width)
793792
rd_clk = []
794793
rd_clk_enable = 0
795794
rd_clk_polarity = 0
796795
rd_transparency_mask = 0
797-
for index, port in enumerate(fragment.read_ports):
798-
if port.domain != "comb":
799-
cd = fragment.domains[port.domain]
796+
for index, port in enumerate(fragment._read_ports):
797+
if port._domain is not None:
798+
cd = fragment.domains[port._domain]
800799
rd_clk.append(cd.clk)
801800
if cd.clk_edge == "pos":
802801
rd_clk_polarity |= 1 << index
803802
rd_clk_enable |= 1 << index
804-
if port.transparent:
805-
for write_index, write_port in enumerate(fragment.write_ports):
806-
if port.domain == write_port.domain:
807-
rd_transparency_mask |= 1 << (index * len(fragment.write_ports) + write_index)
803+
for write_index in port._transparency:
804+
rd_transparency_mask |= 1 << (index * len(fragment._write_ports) + write_index)
808805
else:
809806
rd_clk.append(_ast.Const(0, 1))
810807
wr_clk = []
811808
wr_clk_enable = 0
812809
wr_clk_polarity = 0
813-
for index, port in enumerate(fragment.write_ports):
814-
cd = fragment.domains[port.domain]
810+
for index, port in enumerate(fragment._write_ports):
811+
cd = fragment.domains[port._domain]
815812
wr_clk.append(cd.clk)
816813
wr_clk_enable |= 1 << index
817814
if cd.clk_edge == "pos":
818815
wr_clk_polarity |= 1 << index
819816
params = {
820817
"MEMID": builder._make_name(hierarchy[-1], local=False),
821-
"SIZE": memory.depth,
818+
"SIZE": fragment._depth,
822819
"OFFSET": 0,
823-
"ABITS": _ast.Shape.cast(range(memory.depth)).width,
824-
"WIDTH": memory.width,
820+
"ABITS": _ast.Shape.cast(range(fragment._depth)).width,
821+
"WIDTH": fragment._width,
825822
"INIT": init,
826-
"RD_PORTS": len(fragment.read_ports),
827-
"RD_CLK_ENABLE": _ast.Const(rd_clk_enable, max(1, len(fragment.read_ports))),
828-
"RD_CLK_POLARITY": _ast.Const(rd_clk_polarity, max(1, len(fragment.read_ports))),
829-
"RD_TRANSPARENCY_MASK": _ast.Const(rd_transparency_mask, max(1, len(fragment.read_ports) * len(fragment.write_ports))),
830-
"RD_COLLISION_X_MASK": _ast.Const(0, max(1, len(fragment.read_ports) * len(fragment.write_ports))),
831-
"RD_WIDE_CONTINUATION": _ast.Const(0, max(1, len(fragment.read_ports))),
832-
"RD_CE_OVER_SRST": _ast.Const(0, max(1, len(fragment.read_ports))),
833-
"RD_ARST_VALUE": _ast.Const(0, len(fragment.read_ports) * memory.width),
834-
"RD_SRST_VALUE": _ast.Const(0, len(fragment.read_ports) * memory.width),
835-
"RD_INIT_VALUE": _ast.Const(0, len(fragment.read_ports) * memory.width),
836-
"WR_PORTS": len(fragment.write_ports),
837-
"WR_CLK_ENABLE": _ast.Const(wr_clk_enable, max(1, len(fragment.write_ports))),
838-
"WR_CLK_POLARITY": _ast.Const(wr_clk_polarity, max(1, len(fragment.write_ports))),
839-
"WR_PRIORITY_MASK": _ast.Const(0, max(1, len(fragment.write_ports) * len(fragment.write_ports))),
840-
"WR_WIDE_CONTINUATION": _ast.Const(0, max(1, len(fragment.write_ports))),
823+
"RD_PORTS": len(fragment._read_ports),
824+
"RD_CLK_ENABLE": _ast.Const(rd_clk_enable, max(1, len(fragment._read_ports))),
825+
"RD_CLK_POLARITY": _ast.Const(rd_clk_polarity, max(1, len(fragment._read_ports))),
826+
"RD_TRANSPARENCY_MASK": _ast.Const(rd_transparency_mask, max(1, len(fragment._read_ports) * len(fragment._write_ports))),
827+
"RD_COLLISION_X_MASK": _ast.Const(0, max(1, len(fragment._read_ports) * len(fragment._write_ports))),
828+
"RD_WIDE_CONTINUATION": _ast.Const(0, max(1, len(fragment._read_ports))),
829+
"RD_CE_OVER_SRST": _ast.Const(0, max(1, len(fragment._read_ports))),
830+
"RD_ARST_VALUE": _ast.Const(0, len(fragment._read_ports) * fragment._width),
831+
"RD_SRST_VALUE": _ast.Const(0, len(fragment._read_ports) * fragment._width),
832+
"RD_INIT_VALUE": _ast.Const(0, len(fragment._read_ports) * fragment._width),
833+
"WR_PORTS": len(fragment._write_ports),
834+
"WR_CLK_ENABLE": _ast.Const(wr_clk_enable, max(1, len(fragment._write_ports))),
835+
"WR_CLK_POLARITY": _ast.Const(wr_clk_polarity, max(1, len(fragment._write_ports))),
836+
"WR_PRIORITY_MASK": _ast.Const(0, max(1, len(fragment._write_ports) * len(fragment._write_ports))),
837+
"WR_WIDE_CONTINUATION": _ast.Const(0, max(1, len(fragment._write_ports))),
841838
}
839+
def make_en(port):
840+
if len(port._data) == 0:
841+
return _ast.Const(0, 0)
842+
granularity = len(port._data) // len(port._en)
843+
return _ast.Cat(en_bit.replicate(granularity) for en_bit in port._en)
842844
port_map = {
843845
"\\RD_CLK": _ast.Cat(rd_clk),
844-
"\\RD_EN": _ast.Cat(port.en for port in fragment.read_ports),
845-
"\\RD_ARST": _ast.Const(0, len(fragment.read_ports)),
846-
"\\RD_SRST": _ast.Const(0, len(fragment.read_ports)),
847-
"\\RD_ADDR": _ast.Cat(port.addr for port in fragment.read_ports),
848-
"\\RD_DATA": _ast.Cat(port.data for port in fragment.read_ports),
846+
"\\RD_EN": _ast.Cat(port._en for port in fragment._read_ports),
847+
"\\RD_ARST": _ast.Const(0, len(fragment._read_ports)),
848+
"\\RD_SRST": _ast.Const(0, len(fragment._read_ports)),
849+
"\\RD_ADDR": _ast.Cat(port._addr for port in fragment._read_ports),
850+
"\\RD_DATA": _ast.Cat(port._data for port in fragment._read_ports),
849851
"\\WR_CLK": _ast.Cat(wr_clk),
850-
"\\WR_EN": _ast.Cat(_ast.Cat(en_bit.replicate(port.granularity) for en_bit in port.en) for port in fragment.write_ports),
851-
"\\WR_ADDR": _ast.Cat(port.addr for port in fragment.write_ports),
852-
"\\WR_DATA": _ast.Cat(port.data for port in fragment.write_ports),
852+
"\\WR_EN": _ast.Cat(make_en(port) for port in fragment._write_ports),
853+
"\\WR_ADDR": _ast.Cat(port._addr for port in fragment._write_ports),
854+
"\\WR_DATA": _ast.Cat(port._data for port in fragment._write_ports),
853855
}
854856
return "$mem_v2", port_map, params
855857

@@ -913,7 +915,7 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
913915
if isinstance(subfragment, _ir.Instance):
914916
src = _src(subfragment.src_loc)
915917
elif isinstance(subfragment, _mem.MemoryInstance):
916-
src = _src(subfragment.memory.src_loc)
918+
src = _src(subfragment._src_loc)
917919
else:
918920
src = ""
919921

amaranth/hdl/_ir.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def _merge_subfragment(self, subfragment):
183183

184184
def _resolve_hierarchy_conflicts(self, hierarchy=("top",), mode="warn"):
185185
assert mode in ("silent", "warn", "error")
186+
from ._mem import MemoryInstance
186187

187188
driver_subfrags = SignalDict()
188189
def add_subfrag(registry, entity, entry):
@@ -214,7 +215,7 @@ def add_subfrag(registry, entity, entry):
214215
# Always flatten subfragments that explicitly request it.
215216
flatten_subfrags.add((subfrag, subfrag_hierarchy))
216217

217-
if isinstance(subfrag, Instance):
218+
if isinstance(subfrag, (Instance, MemoryInstance)):
218219
# Never flatten instances.
219220
continue
220221

@@ -368,6 +369,8 @@ def _propagate_domains(self, missing_domain, *, platform=None):
368369
return new_domains
369370

370371
def _prepare_use_def_graph(self, parent, level, uses, defs, ios, top):
372+
from ._mem import MemoryInstance
373+
371374
def add_uses(*sigs, self=self):
372375
for sig in flatten(sigs):
373376
if sig not in uses:
@@ -416,6 +419,22 @@ def add_io(*sigs):
416419
if dir == "io":
417420
subfrag.add_ports(value._lhs_signals(), dir=dir)
418421
add_io(value._lhs_signals())
422+
elif isinstance(subfrag, MemoryInstance):
423+
for port in subfrag._read_ports:
424+
subfrag.add_ports(port._data._lhs_signals(), dir="o")
425+
add_defs(port._data._lhs_signals())
426+
for value in [port._addr, port._en]:
427+
subfrag.add_ports(value._rhs_signals(), dir="i")
428+
add_uses(value._rhs_signals())
429+
for port in subfrag._write_ports:
430+
for value in [port._addr, port._en, port._data]:
431+
subfrag.add_ports(value._rhs_signals(), dir="i")
432+
add_uses(value._rhs_signals())
433+
for domain, _ in subfrag.iter_sync():
434+
cd = subfrag.domains[domain]
435+
add_uses(cd.clk)
436+
if cd.rst is not None:
437+
add_uses(cd.rst)
419438
else:
420439
parent[subfrag] = self
421440
level [subfrag] = level[self] + 1

amaranth/hdl/_mem.py

Lines changed: 108 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,104 @@
33

44
from .. import tracer
55
from ._ast import *
6-
from ._ir import Elaboratable, Instance, Fragment
6+
from ._ir import Elaboratable, Fragment
7+
from ..utils import ceil_log2
78

89

910
__all__ = ["Memory", "ReadPort", "WritePort", "DummyPort"]
1011

1112

13+
class MemoryIdentity: pass
14+
15+
16+
class MemorySimRead:
17+
def __init__(self, identity, addr):
18+
assert isinstance(identity, MemoryIdentity)
19+
self._identity = identity
20+
self._addr = Value.cast(addr)
21+
22+
def eq(self, value):
23+
return MemorySimWrite(self._identity, self._addr, value)
24+
25+
26+
class MemorySimWrite:
27+
def __init__(self, identity, addr, data):
28+
assert isinstance(identity, MemoryIdentity)
29+
self._identity = identity
30+
self._addr = Value.cast(addr)
31+
self._data = Value.cast(data)
32+
33+
34+
class MemoryInstance(Fragment):
35+
class _ReadPort:
36+
def __init__(self, *, domain, addr, data, en, transparency):
37+
assert domain is None or isinstance(domain, str)
38+
if domain == "comb":
39+
domain = None
40+
self._domain = domain
41+
self._addr = Value.cast(addr)
42+
self._data = Value.cast(data)
43+
self._en = Value.cast(en)
44+
self._transparency = tuple(transparency)
45+
assert len(self._en) == 1
46+
if domain is None:
47+
assert isinstance(self._en, Const)
48+
assert self._en.width == 1
49+
assert self._en.value == 1
50+
51+
class _WritePort:
52+
def __init__(self, *, domain, addr, data, en):
53+
assert isinstance(domain, str)
54+
assert domain != "comb"
55+
self._domain = domain
56+
self._addr = Value.cast(addr)
57+
self._data = Value.cast(data)
58+
self._en = Value.cast(en)
59+
if len(self._data):
60+
assert len(self._data) % len(self._en) == 0
61+
62+
@property
63+
def _granularity(self):
64+
if not len(self._data):
65+
return 1
66+
return len(self._data) // len(self._en)
67+
68+
69+
def __init__(self, *, identity, width, depth, init=None, attrs=None, src_loc=None):
70+
super().__init__()
71+
assert isinstance(identity, MemoryIdentity)
72+
self._identity = identity
73+
self._width = operator.index(width)
74+
self._depth = operator.index(depth)
75+
self._init = tuple(init) if init is not None else ()
76+
assert len(self._init) <= self._depth
77+
self._init += (0,) * (self._depth - len(self._init))
78+
for x in self._init:
79+
assert isinstance(x, int)
80+
self._attrs = attrs or {}
81+
self._src_loc = src_loc
82+
self._read_ports = []
83+
self._write_ports = []
84+
85+
def read_port(self, *, domain, addr, data, en, transparency):
86+
port = self._ReadPort(domain=domain, addr=addr, data=data, en=en, transparency=transparency)
87+
assert len(port._data) == self._width
88+
assert len(port._addr) == ceil_log2(self._depth)
89+
for x in port._transparency:
90+
assert isinstance(x, int)
91+
assert x in range(len(self._write_ports))
92+
for signal in port._data._rhs_signals():
93+
self.add_driver(signal, port._domain)
94+
self._read_ports.append(port)
95+
96+
def write_port(self, *, domain, addr, data, en):
97+
port = self._WritePort(domain=domain, addr=addr, data=data, en=en)
98+
assert len(port._data) == self._width
99+
assert len(port._addr) == ceil_log2(self._depth)
100+
self._write_ports.append(port)
101+
return len(self._write_ports) - 1
102+
103+
12104
class Memory(Elaboratable):
13105
"""A word addressable storage.
14106
@@ -50,16 +142,10 @@ def __init__(self, *, width, depth, init=None, name=None, attrs=None, simulate=T
50142
self.depth = depth
51143
self.attrs = OrderedDict(() if attrs is None else attrs)
52144

53-
# Array of signals for simulation.
54-
self._array = Array()
55-
if simulate:
56-
for addr in range(self.depth):
57-
self._array.append(Signal(self.width, name="{}({})"
58-
.format(name or "memory", addr)))
59-
60145
self.init = init
61146
self._read_ports = []
62147
self._write_ports = []
148+
self._identity = MemoryIdentity()
63149

64150
@property
65151
def init(self):
@@ -73,11 +159,8 @@ def init(self, new_init):
73159
.format(len(self.init), self.depth))
74160

75161
try:
76-
for addr in range(len(self._array)):
77-
if addr < len(self._init):
78-
self._array[addr].reset = operator.index(self._init[addr])
79-
else:
80-
self._array[addr].reset = 0
162+
for addr, val in enumerate(self._init):
163+
operator.index(val)
81164
except TypeError as e:
82165
raise TypeError("Memory initialization value at address {:x}: {}"
83166
.format(addr, e)) from None
@@ -116,52 +199,24 @@ def write_port(self, *, src_loc_at=0, **kwargs):
116199

117200
def __getitem__(self, index):
118201
"""Simulation only."""
119-
return self._array[index]
202+
return MemorySimRead(self._identity, index)
120203

121204
def elaborate(self, platform):
122-
f = MemoryInstance(self, self._read_ports, self._write_ports)
205+
f = MemoryInstance(identity=self._identity, width=self.width, depth=self.depth, init=self.init, attrs=self.attrs, src_loc=self.src_loc)
206+
write_ports = {}
207+
for port in self._write_ports:
208+
port._MustUse__used = True
209+
iport = f.write_port(domain=port.domain, addr=port.addr, data=port.data, en=port.en)
210+
write_ports.setdefault(port.domain, []).append(iport)
123211
for port in self._read_ports:
124212
port._MustUse__used = True
125213
if port.domain == "comb":
126-
# Asynchronous port
127-
f.add_statements(None, port.data.eq(self._array[port.addr]))
128-
f.add_driver(port.data)
129-
else:
130-
# Synchronous port
131-
data = self._array[port.addr]
132-
for write_port in self._write_ports:
133-
if port.domain == write_port.domain and port.transparent:
134-
if len(write_port.en) > 1:
135-
parts = []
136-
for index, en_bit in enumerate(write_port.en):
137-
offset = index * write_port.granularity
138-
bits = slice(offset, offset + write_port.granularity)
139-
cond = en_bit & (port.addr == write_port.addr)
140-
parts.append(Mux(cond, write_port.data[bits], data[bits]))
141-
data = Cat(parts)
142-
else:
143-
cond = write_port.en & (port.addr == write_port.addr)
144-
data = Mux(cond, write_port.data, data)
145-
f.add_statements(
146-
port.domain,
147-
Switch(port.en, {
148-
1: port.data.eq(data)
149-
})
150-
)
151-
f.add_driver(port.data, port.domain)
152-
for port in self._write_ports:
153-
port._MustUse__used = True
154-
if len(port.en) > 1:
155-
for index, en_bit in enumerate(port.en):
156-
offset = index * port.granularity
157-
bits = slice(offset, offset + port.granularity)
158-
write_data = self._array[port.addr][bits].eq(port.data[bits])
159-
f.add_statements(port.domain, Switch(en_bit, { 1: write_data }))
214+
f.read_port(domain="comb", addr=port.addr, data=port.data, en=Const(1), transparency=())
160215
else:
161-
write_data = self._array[port.addr].eq(port.data)
162-
f.add_statements(port.domain, Switch(port.en, { 1: write_data }))
163-
for signal in self._array:
164-
f.add_driver(signal, port.domain)
216+
transparency = []
217+
if port.transparent:
218+
transparency = write_ports.get(port.domain, [])
219+
f.read_port(domain=port.domain, addr=port.addr, data=port.data, en=port.en, transparency=transparency)
165220
return f
166221

167222

@@ -308,12 +363,3 @@ def __init__(self, *, data_width, addr_width, domain="sync", name=None, granular
308363
name=f"{name}_data", src_loc_at=1)
309364
self.en = Signal(data_width // granularity,
310365
name=f"{name}_en", src_loc_at=1)
311-
312-
313-
class MemoryInstance(Fragment):
314-
def __init__(self, memory, read_ports, write_ports):
315-
super().__init__()
316-
self.memory = memory
317-
self.read_ports = read_ports
318-
self.write_ports = write_ports
319-
self.attrs = memory.attrs

0 commit comments

Comments
 (0)