|
17 | 17 | "TransformedElaboratable",
|
18 | 18 | "DomainCollector", "DomainRenamer", "DomainLowerer",
|
19 | 19 | "SwitchCleaner", "LHSGroupAnalyzer", "LHSGroupFilter",
|
20 |
| - "ResetInserter", "EnableInserter"] |
| 20 | + "ResetInserter", "EnableInserter", "AssignmentLegalizer"] |
21 | 21 |
|
22 | 22 |
|
23 | 23 | class ValueVisitor(metaclass=ABCMeta):
|
@@ -670,3 +670,85 @@ def on_fragment(self, fragment):
|
670 | 670 | if port._domain in self.controls:
|
671 | 671 | port._en = Mux(self.controls[port._domain], port._en, Const(0, len(port._en)))
|
672 | 672 | return new_fragment
|
| 673 | + |
| 674 | + |
| 675 | +class AssignmentLegalizer(FragmentTransformer, StatementTransformer): |
| 676 | + """Ensures all assignments in switches have one of the following on the LHS: |
| 677 | +
|
| 678 | + - a `Signal` |
| 679 | + - a `Slice` with `value` that is a `Signal` |
| 680 | + """ |
| 681 | + def emit_assign(self, lhs, rhs, lhs_start=0, lhs_stop=None): |
| 682 | + if isinstance(lhs, ArrayProxy): |
| 683 | + # Lower into a switch. |
| 684 | + cases = {} |
| 685 | + for idx, val in enumerate(lhs.elems): |
| 686 | + cases[idx] = self.emit_assign(val, rhs, lhs_start, lhs_stop) |
| 687 | + return [Switch(lhs.index, cases)] |
| 688 | + elif isinstance(lhs, Part): |
| 689 | + offset = lhs.offset |
| 690 | + width = lhs.width |
| 691 | + if lhs_start != 0: |
| 692 | + width -= lhs_start |
| 693 | + if lhs_stop is not None: |
| 694 | + width = lhs_stop - lhs_start |
| 695 | + cases = {} |
| 696 | + lhs_width = len(lhs.value) |
| 697 | + for idx in range(lhs_width): |
| 698 | + start = lhs_start + idx * lhs.stride |
| 699 | + if start >= lhs_width: |
| 700 | + break |
| 701 | + stop = min(start + width, lhs_width) |
| 702 | + cases[idx] = self.emit_assign(lhs.value, rhs, start, stop) |
| 703 | + return [Switch(offset, cases)] |
| 704 | + elif isinstance(lhs, Slice): |
| 705 | + part_start = lhs_start + lhs.start |
| 706 | + if lhs_stop is not None: |
| 707 | + part_stop = lhs_stop + lhs.start |
| 708 | + else: |
| 709 | + part_stop = lhs_start + lhs.stop |
| 710 | + return self.emit_assign(lhs.value, rhs, part_start, part_stop) |
| 711 | + elif isinstance(lhs, Cat): |
| 712 | + # Split into several assignments. |
| 713 | + part_stop = 0 |
| 714 | + res = [] |
| 715 | + if lhs_stop is None: |
| 716 | + lhs_len = len(lhs) - lhs_start |
| 717 | + else: |
| 718 | + lhs_len = lhs_stop - lhs_start |
| 719 | + if len(rhs) < lhs_len: |
| 720 | + rhs |= Const(0, Shape(lhs_len, signed=rhs.shape().signed)) |
| 721 | + for val in lhs.parts: |
| 722 | + part_start = part_stop |
| 723 | + part_len = len(val) |
| 724 | + part_stop = part_start + part_len |
| 725 | + if lhs_start >= part_stop: |
| 726 | + continue |
| 727 | + if lhs_start < part_start: |
| 728 | + part_lhs_start = 0 |
| 729 | + part_rhs_start = part_start - lhs_start |
| 730 | + else: |
| 731 | + part_lhs_start = lhs_start - part_start |
| 732 | + part_rhs_start = 0 |
| 733 | + if lhs_stop is not None and lhs_stop <= part_start: |
| 734 | + continue |
| 735 | + elif lhs_stop is None or lhs_stop >= part_stop: |
| 736 | + part_lhs_stop = None |
| 737 | + else: |
| 738 | + part_lhs_stop = lhs_stop - part_start |
| 739 | + res += self.emit_assign(val, rhs[part_rhs_start:], part_lhs_start, part_lhs_stop) |
| 740 | + return res |
| 741 | + elif isinstance(lhs, Signal): |
| 742 | + # Already ok. |
| 743 | + if lhs_start != 0 or lhs_stop is not None: |
| 744 | + return [Assign(lhs[lhs_start:lhs_stop], rhs)] |
| 745 | + else: |
| 746 | + return [Assign(lhs, rhs)] |
| 747 | + elif isinstance(lhs, Operator): |
| 748 | + assert lhs.operator in ('u', 's') |
| 749 | + return self.emit_assign(lhs.operands[0], rhs, lhs_start, lhs_stop) |
| 750 | + else: |
| 751 | + raise TypeError |
| 752 | + |
| 753 | + def on_Assign(self, stmt): |
| 754 | + return self.emit_assign(stmt.lhs, stmt.rhs) |
0 commit comments