|
| 1 | +# TODO(amaranth-0.6): remove module |
| 2 | + |
| 3 | +from enum import Enum |
| 4 | +from collections import OrderedDict |
| 5 | +from functools import reduce, wraps |
| 6 | + |
| 7 | +from .. import tracer |
| 8 | +from .._utils import union |
| 9 | +from .ast import * |
| 10 | + |
| 11 | + |
| 12 | +__all__ = ["Direction", "DIR_NONE", "DIR_FANOUT", "DIR_FANIN", "Layout", "Record"] |
| 13 | + |
| 14 | + |
| 15 | +Direction = Enum('Direction', ('NONE', 'FANOUT', 'FANIN')) |
| 16 | + |
| 17 | +DIR_NONE = Direction.NONE |
| 18 | +DIR_FANOUT = Direction.FANOUT |
| 19 | +DIR_FANIN = Direction.FANIN |
| 20 | + |
| 21 | + |
| 22 | +class Layout: |
| 23 | + @staticmethod |
| 24 | + def cast(obj, *, src_loc_at=0): |
| 25 | + if isinstance(obj, Layout): |
| 26 | + return obj |
| 27 | + return Layout(obj, src_loc_at=1 + src_loc_at) |
| 28 | + |
| 29 | + def __init__(self, fields, *, src_loc_at=0): |
| 30 | + self.fields = OrderedDict() |
| 31 | + for field in fields: |
| 32 | + if not isinstance(field, tuple) or len(field) not in (2, 3): |
| 33 | + raise TypeError("Field {!r} has invalid layout: should be either " |
| 34 | + "(name, shape) or (name, shape, direction)" |
| 35 | + .format(field)) |
| 36 | + if len(field) == 2: |
| 37 | + name, shape = field |
| 38 | + direction = DIR_NONE |
| 39 | + if isinstance(shape, list): |
| 40 | + shape = Layout.cast(shape) |
| 41 | + else: |
| 42 | + name, shape, direction = field |
| 43 | + if not isinstance(direction, Direction): |
| 44 | + raise TypeError("Field {!r} has invalid direction: should be a Direction " |
| 45 | + "instance like DIR_FANIN" |
| 46 | + .format(field)) |
| 47 | + if not isinstance(name, str): |
| 48 | + raise TypeError("Field {!r} has invalid name: should be a string" |
| 49 | + .format(field)) |
| 50 | + if not isinstance(shape, Layout): |
| 51 | + try: |
| 52 | + # Check provided shape by calling Shape.cast and checking for exception |
| 53 | + Shape.cast(shape, src_loc_at=1 + src_loc_at) |
| 54 | + except Exception: |
| 55 | + raise TypeError("Field {!r} has invalid shape: should be castable to Shape " |
| 56 | + "or a list of fields of a nested record" |
| 57 | + .format(field)) |
| 58 | + if name in self.fields: |
| 59 | + raise NameError("Field {!r} has a name that is already present in the layout" |
| 60 | + .format(field)) |
| 61 | + self.fields[name] = (shape, direction) |
| 62 | + |
| 63 | + def __getitem__(self, item): |
| 64 | + if isinstance(item, tuple): |
| 65 | + return Layout([ |
| 66 | + (name, shape, dir) |
| 67 | + for (name, (shape, dir)) in self.fields.items() |
| 68 | + if name in item |
| 69 | + ]) |
| 70 | + |
| 71 | + return self.fields[item] |
| 72 | + |
| 73 | + def __iter__(self): |
| 74 | + for name, (shape, dir) in self.fields.items(): |
| 75 | + yield (name, shape, dir) |
| 76 | + |
| 77 | + def __eq__(self, other): |
| 78 | + return self.fields == other.fields |
| 79 | + |
| 80 | + def __repr__(self): |
| 81 | + field_reprs = [] |
| 82 | + for name, shape, dir in self: |
| 83 | + if dir == DIR_NONE: |
| 84 | + field_reprs.append("({!r}, {!r})".format(name, shape)) |
| 85 | + else: |
| 86 | + field_reprs.append("({!r}, {!r}, Direction.{})".format(name, shape, dir.name)) |
| 87 | + return "Layout([{}])".format(", ".join(field_reprs)) |
| 88 | + |
| 89 | + |
| 90 | +class Record(ValueCastable): |
| 91 | + @staticmethod |
| 92 | + def like(other, *, name=None, name_suffix=None, src_loc_at=0): |
| 93 | + if name is not None: |
| 94 | + new_name = str(name) |
| 95 | + elif name_suffix is not None: |
| 96 | + new_name = other.name + str(name_suffix) |
| 97 | + else: |
| 98 | + new_name = tracer.get_var_name(depth=2 + src_loc_at, default=None) |
| 99 | + |
| 100 | + def concat(a, b): |
| 101 | + if a is None: |
| 102 | + return b |
| 103 | + return "{}__{}".format(a, b) |
| 104 | + |
| 105 | + fields = {} |
| 106 | + for field_name in other.fields: |
| 107 | + field = other[field_name] |
| 108 | + if isinstance(field, Record): |
| 109 | + fields[field_name] = Record.like(field, name=concat(new_name, field_name), |
| 110 | + src_loc_at=1 + src_loc_at) |
| 111 | + else: |
| 112 | + fields[field_name] = Signal.like(field, name=concat(new_name, field_name), |
| 113 | + src_loc_at=1 + src_loc_at) |
| 114 | + |
| 115 | + return Record(other.layout, name=new_name, fields=fields, src_loc_at=1) |
| 116 | + |
| 117 | + def __init__(self, layout, *, name=None, fields=None, src_loc_at=0): |
| 118 | + if name is None: |
| 119 | + name = tracer.get_var_name(depth=2 + src_loc_at, default=None) |
| 120 | + |
| 121 | + self.name = name |
| 122 | + self.src_loc = tracer.get_src_loc(src_loc_at) |
| 123 | + |
| 124 | + def concat(a, b): |
| 125 | + if a is None: |
| 126 | + return b |
| 127 | + return "{}__{}".format(a, b) |
| 128 | + |
| 129 | + self.layout = Layout.cast(layout, src_loc_at=1 + src_loc_at) |
| 130 | + self.fields = OrderedDict() |
| 131 | + for field_name, field_shape, field_dir in self.layout: |
| 132 | + if fields is not None and field_name in fields: |
| 133 | + field = fields[field_name] |
| 134 | + if isinstance(field_shape, Layout): |
| 135 | + assert isinstance(field, Record) and field_shape == field.layout |
| 136 | + else: |
| 137 | + assert isinstance(field, Signal) and Shape.cast(field_shape) == field.shape() |
| 138 | + self.fields[field_name] = field |
| 139 | + else: |
| 140 | + if isinstance(field_shape, Layout): |
| 141 | + self.fields[field_name] = Record(field_shape, name=concat(name, field_name), |
| 142 | + src_loc_at=1 + src_loc_at) |
| 143 | + else: |
| 144 | + self.fields[field_name] = Signal(field_shape, name=concat(name, field_name), |
| 145 | + src_loc_at=1 + src_loc_at) |
| 146 | + |
| 147 | + def __getattr__(self, name): |
| 148 | + return self[name] |
| 149 | + |
| 150 | + def __getitem__(self, item): |
| 151 | + if isinstance(item, str): |
| 152 | + try: |
| 153 | + return self.fields[item] |
| 154 | + except KeyError: |
| 155 | + if self.name is None: |
| 156 | + reference = "Unnamed record" |
| 157 | + else: |
| 158 | + reference = "Record '{}'".format(self.name) |
| 159 | + raise AttributeError("{} does not have a field '{}'. Did you mean one of: {}?" |
| 160 | + .format(reference, item, ", ".join(self.fields))) from None |
| 161 | + elif isinstance(item, tuple): |
| 162 | + return Record(self.layout[item], fields={ |
| 163 | + field_name: field_value |
| 164 | + for field_name, field_value in self.fields.items() |
| 165 | + if field_name in item |
| 166 | + }) |
| 167 | + else: |
| 168 | + try: |
| 169 | + return Value.__getitem__(self, item) |
| 170 | + except KeyError: |
| 171 | + if self.name is None: |
| 172 | + reference = "Unnamed record" |
| 173 | + else: |
| 174 | + reference = "Record '{}'".format(self.name) |
| 175 | + raise AttributeError("{} does not have a field '{}'. Did you mean one of: {}?" |
| 176 | + .format(reference, item, ", ".join(self.fields))) from None |
| 177 | + |
| 178 | + @ValueCastable.lowermethod |
| 179 | + def as_value(self): |
| 180 | + return Cat(self.fields.values()) |
| 181 | + |
| 182 | + def __len__(self): |
| 183 | + return len(self.as_value()) |
| 184 | + |
| 185 | + def _lhs_signals(self): |
| 186 | + return union((f._lhs_signals() for f in self.fields.values()), start=SignalSet()) |
| 187 | + |
| 188 | + def _rhs_signals(self): |
| 189 | + return union((f._rhs_signals() for f in self.fields.values()), start=SignalSet()) |
| 190 | + |
| 191 | + def __repr__(self): |
| 192 | + fields = [] |
| 193 | + for field_name, field in self.fields.items(): |
| 194 | + if isinstance(field, Signal): |
| 195 | + fields.append(field_name) |
| 196 | + else: |
| 197 | + fields.append(repr(field)) |
| 198 | + name = self.name |
| 199 | + if name is None: |
| 200 | + name = "<unnamed>" |
| 201 | + return "(rec {} {})".format(name, " ".join(fields)) |
| 202 | + |
| 203 | + def shape(self): |
| 204 | + return self.as_value().shape() |
| 205 | + |
| 206 | + def connect(self, *subordinates, include=None, exclude=None): |
| 207 | + def rec_name(record): |
| 208 | + if record.name is None: |
| 209 | + return "unnamed record" |
| 210 | + else: |
| 211 | + return "record '{}'".format(record.name) |
| 212 | + |
| 213 | + for field in include or {}: |
| 214 | + if field not in self.fields: |
| 215 | + raise AttributeError("Cannot include field '{}' because it is not present in {}" |
| 216 | + .format(field, rec_name(self))) |
| 217 | + for field in exclude or {}: |
| 218 | + if field not in self.fields: |
| 219 | + raise AttributeError("Cannot exclude field '{}' because it is not present in {}" |
| 220 | + .format(field, rec_name(self))) |
| 221 | + |
| 222 | + stmts = [] |
| 223 | + for field in self.fields: |
| 224 | + if include is not None and field not in include: |
| 225 | + continue |
| 226 | + if exclude is not None and field in exclude: |
| 227 | + continue |
| 228 | + |
| 229 | + shape, direction = self.layout[field] |
| 230 | + if not isinstance(shape, Layout) and direction == DIR_NONE: |
| 231 | + raise TypeError("Cannot connect field '{}' of {} because it does not have " |
| 232 | + "a direction" |
| 233 | + .format(field, rec_name(self))) |
| 234 | + |
| 235 | + item = self.fields[field] |
| 236 | + subord_items = [] |
| 237 | + for subord in subordinates: |
| 238 | + if field not in subord.fields: |
| 239 | + raise AttributeError("Cannot connect field '{}' of {} to subordinate {} " |
| 240 | + "because the subordinate record does not have this field" |
| 241 | + .format(field, rec_name(self), rec_name(subord))) |
| 242 | + subord_items.append(subord.fields[field]) |
| 243 | + |
| 244 | + if isinstance(shape, Layout): |
| 245 | + sub_include = include[field] if include and field in include else None |
| 246 | + sub_exclude = exclude[field] if exclude and field in exclude else None |
| 247 | + stmts += item.connect(*subord_items, include=sub_include, exclude=sub_exclude) |
| 248 | + else: |
| 249 | + if direction == DIR_FANOUT: |
| 250 | + stmts += [sub_item.eq(item) for sub_item in subord_items] |
| 251 | + if direction == DIR_FANIN: |
| 252 | + stmts += [item.eq(reduce(lambda a, b: a | b, subord_items))] |
| 253 | + |
| 254 | + return stmts |
| 255 | + |
| 256 | +def _valueproxy(name): |
| 257 | + value_func = getattr(Value, name) |
| 258 | + @wraps(value_func) |
| 259 | + def _wrapper(self, *args, **kwargs): |
| 260 | + return value_func(Value.cast(self), *args, **kwargs) |
| 261 | + return _wrapper |
| 262 | + |
| 263 | +for name in [ |
| 264 | + "__bool__", |
| 265 | + "__invert__", "__neg__", |
| 266 | + "__add__", "__radd__", "__sub__", "__rsub__", |
| 267 | + "__mul__", "__rmul__", |
| 268 | + "__mod__", "__rmod__", "__floordiv__", "__rfloordiv__", |
| 269 | + "__lshift__", "__rlshift__", "__rshift__", "__rrshift__", |
| 270 | + "__and__", "__rand__", "__xor__", "__rxor__", "__or__", "__ror__", |
| 271 | + "__eq__", "__ne__", "__lt__", "__le__", "__gt__", "__ge__", |
| 272 | + "__abs__", "__len__", |
| 273 | + "as_unsigned", "as_signed", "bool", "any", "all", "xor", "implies", |
| 274 | + "bit_select", "word_select", "matches", |
| 275 | + "shift_left", "shift_right", "rotate_left", "rotate_right", "eq" |
| 276 | + ]: |
| 277 | + setattr(Record, name, _valueproxy(name)) |
| 278 | + |
| 279 | +del _valueproxy |
| 280 | +del name |
0 commit comments