Skip to content

Commit 8c4a15a

Browse files
wanda-phiwhitequark
authored andcommitted
hdl.mem: lower Memory directly to $mem_v2 RTLIL cell.
The design decision of using split memory ports in the internal representation (copied from Yosys) was misguided and caused no end of misery. Remove any uses of `$memrd`/`$memwr` and lower memories directly to a combined memory cell, currently the RTLIL one.
1 parent fc85feb commit 8c4a15a

File tree

10 files changed

+183
-193
lines changed

10 files changed

+183
-193
lines changed

amaranth/back/rtlil.py

Lines changed: 5 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -863,49 +863,21 @@ def _convert_fragment(builder, fragment, name_map, hierarchy):
863863
if sub_name is None:
864864
sub_name = module.anonymous()
865865

866-
sub_params = OrderedDict()
867-
if hasattr(subfragment, "parameters"):
868-
for param_name, param_value in subfragment.parameters.items():
869-
if isinstance(param_value, mem.Memory):
870-
memory = param_value
871-
if memory not in memories:
872-
memories[memory] = module.memory(width=memory.width, size=memory.depth,
873-
name=memory.name, attrs=memory.attrs)
874-
addr_bits = bits_for(memory.depth)
875-
data_parts = []
876-
data_mask = (1 << memory.width) - 1
877-
for addr in range(memory.depth):
878-
if addr < len(memory.init):
879-
data = memory.init[addr] & data_mask
880-
else:
881-
data = 0
882-
data_parts.append("{:0{}b}".format(data, memory.width))
883-
module.cell("$meminit", ports={
884-
"\\ADDR": rhs_compiler(ast.Const(0, addr_bits)),
885-
"\\DATA": "{}'".format(memory.width * memory.depth) +
886-
"".join(reversed(data_parts)),
887-
}, params={
888-
"MEMID": memories[memory],
889-
"ABITS": addr_bits,
890-
"WIDTH": memory.width,
891-
"WORDS": memory.depth,
892-
"PRIORITY": 0,
893-
})
894-
895-
param_value = memories[memory]
896-
897-
sub_params[param_name] = param_value
866+
sub_params = OrderedDict(getattr(subfragment, "parameters", {}))
898867

899868
sub_type, sub_port_map = \
900869
_convert_fragment(builder, subfragment, name_map,
901870
hierarchy=hierarchy + (sub_name,))
902871

872+
if sub_type == "$mem_v2" and "MEMID" not in sub_params:
873+
sub_params["MEMID"] = "$" + sub_name
874+
903875
sub_ports = OrderedDict()
904876
for port, value in sub_port_map.items():
905877
if not isinstance(subfragment, ir.Instance):
906878
for signal in value._rhs_signals():
907879
compiler_state.resolve_curr(signal, prefix=sub_name)
908-
if len(value) > 0:
880+
if len(value) > 0 or sub_type == "$mem_v2":
909881
sub_ports[port] = rhs_compiler(value)
910882

911883
module.cell(sub_type, name=sub_name, ports=sub_ports, params=sub_params,

amaranth/compat/fhdl/specials.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,6 @@ def __init__(self, adr, dat_r, we=None, dat_w=None, async_read=False, re=None,
8383
self.clock = ClockSignal(clock_domain)
8484

8585

86-
@extend(NativeMemory)
87-
@deprecated("it is not necessary or permitted to add Memory as a special or submodule")
88-
def elaborate(self, platform):
89-
return Fragment()
90-
91-
9286
class CompatMemory(NativeMemory, Elaboratable):
9387
def __init__(self, width, depth, init=None, name=None):
9488
super().__init__(width=width, depth=depth, init=init, name=name)

amaranth/hdl/ir.py

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,6 @@ def _resolve_hierarchy_conflicts(self, hierarchy=("top",), mode="warn"):
181181
assert mode in ("silent", "warn", "error")
182182

183183
driver_subfrags = SignalDict()
184-
memory_subfrags = OrderedDict()
185184
def add_subfrag(registry, entity, entry):
186185
# Because of missing domain insertion, at the point when this code runs, we have
187186
# a mixture of bound and unbound {Clock,Reset}Signals. Map the bound ones to
@@ -212,24 +211,16 @@ def add_subfrag(registry, entity, entry):
212211
flatten_subfrags.add((subfrag, subfrag_hierarchy))
213212

214213
if isinstance(subfrag, Instance):
215-
# For memories (which are subfragments, but semantically a part of superfragment),
216-
# record that this fragment is driving it.
217-
if subfrag.type in ("$memrd", "$memwr"):
218-
memory = subfrag.parameters["MEMID"]
219-
add_subfrag(memory_subfrags, memory, (None, hierarchy))
220-
221214
# Never flatten instances.
222215
continue
223216

224217
# First, recurse into subfragments and let them detect driver conflicts as well.
225-
subfrag_drivers, subfrag_memories = \
218+
subfrag_drivers = \
226219
subfrag._resolve_hierarchy_conflicts(subfrag_hierarchy, mode)
227220

228-
# Second, classify subfragments by signals they drive and memories they use.
221+
# Second, classify subfragments by signals they drive.
229222
for signal in subfrag_drivers:
230223
add_subfrag(driver_subfrags, signal, (subfrag, subfrag_hierarchy))
231-
for memory in subfrag_memories:
232-
add_subfrag(memory_subfrags, memory, (subfrag, subfrag_hierarchy))
233224

234225
# Find out the set of subfragments that needs to be flattened into this fragment
235226
# to resolve driver-driver conflicts.
@@ -253,20 +244,6 @@ def flatten_subfrags_if_needed(subfrags):
253244
message += "; hierarchy will be flattened"
254245
warnings.warn_explicit(message, DriverConflict, *signal.src_loc)
255246

256-
for memory, subfrags in memory_subfrags.items():
257-
subfrag_names = flatten_subfrags_if_needed(subfrags)
258-
if not subfrag_names:
259-
continue
260-
261-
# While we're at it, show a message.
262-
message = ("Memory '{}' is accessed from multiple fragments: {}"
263-
.format(memory.name, ", ".join(subfrag_names)))
264-
if mode == "error":
265-
raise DriverConflict(message)
266-
elif mode == "warn":
267-
message += "; hierarchy will be flattened"
268-
warnings.warn_explicit(message, DriverConflict, *memory.src_loc)
269-
270247
# Flatten hierarchy.
271248
for subfrag, subfrag_hierarchy in sorted(flatten_subfrags, key=lambda x: x[1]):
272249
self._merge_subfragment(subfrag)
@@ -282,8 +259,7 @@ def flatten_subfrags_if_needed(subfrags):
282259
return self._resolve_hierarchy_conflicts(hierarchy, mode)
283260

284261
# Nothing was flattened, we're done!
285-
return (SignalSet(driver_subfrags.keys()),
286-
set(memory_subfrags.keys()))
262+
return SignalSet(driver_subfrags.keys())
287263

288264
def _propagate_domains_up(self, hierarchy=("top",)):
289265
from .xfrm import DomainRenamer

amaranth/hdl/mem.py

Lines changed: 106 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33

44
from .. import tracer
55
from .ast import *
6-
from .ir import Elaboratable, Instance
6+
from .ir import Elaboratable, Instance, Fragment
77

88

99
__all__ = ["Memory", "ReadPort", "WritePort", "DummyPort"]
1010

1111

12-
class Memory:
12+
class Memory(Elaboratable):
1313
"""A word addressable storage.
1414
1515
Parameters
@@ -58,6 +58,8 @@ def __init__(self, *, width, depth, init=None, name=None, attrs=None, simulate=T
5858
.format(name or "memory", addr)))
5959

6060
self.init = init
61+
self._read_ports = []
62+
self._write_ports = []
6163

6264
@property
6365
def init(self):
@@ -116,6 +118,96 @@ def __getitem__(self, index):
116118
"""Simulation only."""
117119
return self._array[index]
118120

121+
def elaborate(self, platform):
122+
init = "".join(format(Const(elem, unsigned(self.width)).value, f"0{self.width}b") for elem in reversed(self.init))
123+
init = Const(int(init or "0", 2), len(self.init) * self.width)
124+
rd_clk = []
125+
rd_clk_enable = 0
126+
rd_transparency_mask = 0
127+
for index, port in enumerate(self._read_ports):
128+
if port.domain != "comb":
129+
rd_clk.append(ClockSignal(port.domain))
130+
rd_clk_enable |= 1 << index
131+
if port.transparent:
132+
for write_index, write_port in enumerate(self._write_ports):
133+
if port.domain == write_port.domain:
134+
rd_transparency_mask |= 1 << (index * len(self._write_ports) + write_index)
135+
else:
136+
rd_clk.append(Const(0, 1))
137+
f = Instance("$mem_v2",
138+
*(("a", attr, value) for attr, value in self.attrs.items()),
139+
p_SIZE=self.depth,
140+
p_OFFSET=0,
141+
p_ABITS=Shape.cast(range(self.depth)).width,
142+
p_WIDTH=self.width,
143+
p_INIT=init,
144+
p_RD_PORTS=len(self._read_ports),
145+
p_RD_CLK_ENABLE=Const(rd_clk_enable, len(self._read_ports)) if self._read_ports else Const(0, 1),
146+
p_RD_CLK_POLARITY=Const(-1, unsigned(len(self._read_ports))) if self._read_ports else Const(0, 1),
147+
p_RD_TRANSPARENCY_MASK=Const(rd_transparency_mask, max(1, len(self._read_ports) * len(self._write_ports))),
148+
p_RD_COLLISION_X_MASK=Const(0, max(1, len(self._read_ports) * len(self._write_ports))),
149+
p_RD_WIDE_CONTINUATION=Const(0, len(self._read_ports)) if self._read_ports else Const(0, 1),
150+
p_RD_CE_OVER_SRST=Const(0, len(self._read_ports)) if self._read_ports else Const(0, 1),
151+
p_RD_ARST_VALUE=Const(0, len(self._read_ports) * self.width),
152+
p_RD_SRST_VALUE=Const(0, len(self._read_ports) * self.width),
153+
p_RD_INIT_VALUE=Const(0, len(self._read_ports) * self.width),
154+
p_WR_PORTS=len(self._write_ports),
155+
p_WR_CLK_ENABLE=Const(-1, unsigned(len(self._write_ports))) if self._write_ports else Const(0, 1),
156+
p_WR_CLK_POLARITY=Const(-1, unsigned(len(self._write_ports))) if self._write_ports else Const(0, 1),
157+
p_WR_PRIORITY_MASK=Const(0, len(self._write_ports) * len(self._write_ports)) if self._write_ports else Const(0, 1),
158+
p_WR_WIDE_CONTINUATION=Const(0, len(self._write_ports)) if self._write_ports else Const(0, 1),
159+
i_RD_CLK=Cat(rd_clk),
160+
i_RD_EN=Cat(port.en for port in self._read_ports),
161+
i_RD_ARST=Const(0, len(self._read_ports)),
162+
i_RD_SRST=Const(0, len(self._read_ports)),
163+
i_RD_ADDR=Cat(port.addr for port in self._read_ports),
164+
o_RD_DATA=Cat(port.data for port in self._read_ports),
165+
i_WR_CLK=Cat(ClockSignal(port.domain) for port in self._write_ports),
166+
i_WR_EN=Cat(Cat(en_bit.replicate(port.granularity) for en_bit in port.en) for port in self._write_ports),
167+
i_WR_ADDR=Cat(port.addr for port in self._write_ports),
168+
i_WR_DATA=Cat(port.data for port in self._write_ports),
169+
)
170+
for port in self._read_ports:
171+
port._MustUse__used = True
172+
if port.domain == "comb":
173+
# Asynchronous port
174+
f.add_statements(port.data.eq(self._array[port.addr]))
175+
f.add_driver(port.data)
176+
else:
177+
# Synchronous port
178+
data = self._array[port.addr]
179+
for write_port in self._write_ports:
180+
if port.domain == write_port.domain and port.transparent:
181+
if len(write_port.en) > 1:
182+
parts = []
183+
for index, en_bit in enumerate(write_port.en):
184+
offset = index * write_port.granularity
185+
bits = slice(offset, offset + write_port.granularity)
186+
cond = en_bit & (port.addr == write_port.addr)
187+
parts.append(Mux(cond, write_port.data[bits], data[bits]))
188+
data = Cat(parts)
189+
else:
190+
data = Mux(write_port.en, write_port.data, data)
191+
f.add_statements(
192+
Switch(port.en, {
193+
1: port.data.eq(data)
194+
})
195+
)
196+
f.add_driver(port.data, port.domain)
197+
for port in self._write_ports:
198+
port._MustUse__used = True
199+
if len(port.en) > 1:
200+
for index, en_bit in enumerate(port.en):
201+
offset = index * port.granularity
202+
bits = slice(offset, offset + port.granularity)
203+
write_data = self._array[port.addr][bits].eq(port.data[bits])
204+
f.add_statements(Switch(en_bit, { 1: write_data }))
205+
else:
206+
write_data = self._array[port.addr].eq(port.data)
207+
f.add_statements(Switch(port.en, { 1: write_data }))
208+
for signal in self._array:
209+
f.add_driver(signal, port.domain)
210+
return f
119211

120212
class ReadPort(Elaboratable):
121213
"""A memory read port.
@@ -142,9 +234,7 @@ class ReadPort(Elaboratable):
142234
data : Signal(memory.width), out
143235
Read data.
144236
en : Signal or Const, in
145-
Read enable. If asserted, ``data`` is updated with the word stored at ``addr``. Note that
146-
transparent ports cannot assign ``en`` (which is hardwired to 1 instead), as doing so is
147-
currently not supported by Yosys.
237+
Read enable. If asserted, ``data`` is updated with the word stored at ``addr``.
148238
149239
Exceptions
150240
----------
@@ -162,59 +252,19 @@ def __init__(self, memory, *, domain="sync", transparent=True, src_loc_at=0):
162252
name="{}_r_addr".format(memory.name), src_loc_at=1 + src_loc_at)
163253
self.data = Signal(memory.width,
164254
name="{}_r_data".format(memory.name), src_loc_at=1 + src_loc_at)
165-
if self.domain != "comb" and not transparent:
255+
if self.domain != "comb":
166256
self.en = Signal(name="{}_r_en".format(memory.name), reset=1,
167257
src_loc_at=1 + src_loc_at)
168258
else:
169259
self.en = Const(1)
170260

261+
memory._read_ports.append(self)
262+
171263
def elaborate(self, platform):
172-
f = Instance("$memrd",
173-
p_MEMID=self.memory,
174-
p_ABITS=self.addr.width,
175-
p_WIDTH=self.data.width,
176-
p_CLK_ENABLE=self.domain != "comb",
177-
p_CLK_POLARITY=1,
178-
p_TRANSPARENT=self.transparent,
179-
i_CLK=ClockSignal(self.domain) if self.domain != "comb" else Const(0),
180-
i_EN=self.en,
181-
i_ADDR=self.addr,
182-
o_DATA=self.data,
183-
)
184-
if self.domain == "comb":
185-
# Asynchronous port
186-
f.add_statements(self.data.eq(self.memory._array[self.addr]))
187-
f.add_driver(self.data)
188-
elif not self.transparent:
189-
# Synchronous, read-before-write port
190-
f.add_statements(
191-
Switch(self.en, {
192-
1: self.data.eq(self.memory._array[self.addr])
193-
})
194-
)
195-
f.add_driver(self.data, self.domain)
264+
if self is self.memory._read_ports[0]:
265+
return self.memory
196266
else:
197-
# Synchronous, write-through port
198-
# This model is a bit unconventional. We model transparent ports as asynchronous ports
199-
# that are latched when the clock is high. This isn't exactly correct, but it is very
200-
# close to the correct behavior of a transparent port, and the difference should only
201-
# be observable in pathological cases of clock gating. A register is injected to
202-
# the address input to achieve the correct address-to-data latency. Also, the reset
203-
# value of the data output is forcibly set to the 0th initial value, if any--note that
204-
# many FPGAs do not guarantee this behavior!
205-
if len(self.memory.init) > 0:
206-
self.data.reset = operator.index(self.memory.init[0])
207-
latch_addr = Signal.like(self.addr)
208-
f.add_statements(
209-
latch_addr.eq(self.addr),
210-
Switch(ClockSignal(self.domain), {
211-
0: self.data.eq(self.data),
212-
1: self.data.eq(self.memory._array[latch_addr]),
213-
}),
214-
)
215-
f.add_driver(latch_addr, self.domain)
216-
f.add_driver(self.data)
217-
return f
267+
return Fragment()
218268

219269

220270
class WritePort(Elaboratable):
@@ -272,31 +322,13 @@ def __init__(self, memory, *, domain="sync", granularity=None, src_loc_at=0):
272322
self.en = Signal(memory.width // granularity,
273323
name="{}_w_en".format(memory.name), src_loc_at=1 + src_loc_at)
274324

325+
memory._write_ports.append(self)
326+
275327
def elaborate(self, platform):
276-
f = Instance("$memwr",
277-
p_MEMID=self.memory,
278-
p_ABITS=self.addr.width,
279-
p_WIDTH=self.data.width,
280-
p_CLK_ENABLE=1,
281-
p_CLK_POLARITY=1,
282-
p_PRIORITY=0,
283-
i_CLK=ClockSignal(self.domain),
284-
i_EN=Cat(en_bit.replicate(self.granularity) for en_bit in self.en),
285-
i_ADDR=self.addr,
286-
i_DATA=self.data,
287-
)
288-
if len(self.en) > 1:
289-
for index, en_bit in enumerate(self.en):
290-
offset = index * self.granularity
291-
bits = slice(offset, offset + self.granularity)
292-
write_data = self.memory._array[self.addr][bits].eq(self.data[bits])
293-
f.add_statements(Switch(en_bit, { 1: write_data }))
328+
if not self.memory._read_ports and self is self.memory._write_ports[0]:
329+
return self.memory
294330
else:
295-
write_data = self.memory._array[self.addr].eq(self.data)
296-
f.add_statements(Switch(self.en, { 1: write_data }))
297-
for signal in self.memory._array:
298-
f.add_driver(signal, self.domain)
299-
return f
331+
return Fragment()
300332

301333

302334
class DummyPort:

amaranth/hdl/xfrm.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -720,10 +720,14 @@ def _insert_control(self, fragment, domain, signals):
720720

721721
def on_fragment(self, fragment):
722722
new_fragment = super().on_fragment(fragment)
723-
if isinstance(new_fragment, Instance) and new_fragment.type in ("$memrd", "$memwr"):
724-
clk_port, clk_dir = new_fragment.named_ports["CLK"]
725-
if isinstance(clk_port, ClockSignal) and clk_port.domain in self.controls:
726-
en_port, en_dir = new_fragment.named_ports["EN"]
727-
en_port = Mux(self.controls[clk_port.domain], en_port, Const(0, len(en_port)))
728-
new_fragment.named_ports["EN"] = en_port, en_dir
723+
if isinstance(new_fragment, Instance) and new_fragment.type == "$mem_v2":
724+
for kind in ["RD", "WR"]:
725+
clk_parts = new_fragment.named_ports[kind + "_CLK"][0].parts
726+
en_parts = new_fragment.named_ports[kind + "_EN"][0].parts
727+
new_en = []
728+
for clk, en in zip(clk_parts, en_parts):
729+
if isinstance(clk, ClockSignal) and clk.domain in self.controls:
730+
en = Mux(self.controls[clk.domain], en, Const(0, len(en)))
731+
new_en.append(en)
732+
new_fragment.named_ports[kind + "_EN"] = Cat(new_en), "i"
729733
return new_fragment

0 commit comments

Comments
 (0)