@@ -1371,6 +1371,7 @@ def connect(m, *args, **kwargs):
1371
1371
* For a given path, if any of the interface objects has an input port member corresponding
1372
1372
to a constant value, then the rest of the interface objects must have output port members
1373
1373
corresponding to the same constant value.
1374
+ * When connecting multiple interface objects, at least one connection must be made.
1374
1375
1375
1376
For example, if :py:`obj1` is being connected to :py:`obj2` and :py:`obj3`, and :py:`obj1.a.b`
1376
1377
is an output, then :py:`obj2.a.b` and :py:`obj2.a.b` must exist and be inputs. If :py:`obj2.c`
@@ -1420,10 +1421,15 @@ def connect(m, *args, **kwargs):
1420
1421
reasons_as_string )
1421
1422
signatures [handle ] = obj .signature
1422
1423
1423
- # Collate signatures and build connections.
1424
+ # Connecting zero or one signatures is OK.
1425
+ if len (signatures ) <= 1 :
1426
+ return
1427
+
1428
+ # Collate signatures, build connections, track whether we see any input or output.
1424
1429
flattens = {handle : iter (sorted (signature .members .flatten ()))
1425
1430
for handle , signature in signatures .items ()}
1426
1431
connections = []
1432
+ any_in , any_out = False , False
1427
1433
# Each iteration of the outer loop is intended to connect several (usually a pair) members
1428
1434
# to each other, e.g. an out member `[0].a` to an in member `[1].a`. However, because we
1429
1435
# do not just check signatures for equality (in order to improve diagnostics), it is possible
@@ -1437,6 +1443,7 @@ def connect(m, *args, **kwargs):
1437
1443
# implied in the flow of each port member, so the signature members are only classified
1438
1444
# here to ensure they are not connected to port members.
1439
1445
is_first = True
1446
+ first_path = None
1440
1447
sig_kind , out_kind , in_kind = [], [], []
1441
1448
for handle , flattened_members in flattens .items ():
1442
1449
path_for_handle , member = next (flattened_members , (None , None ))
@@ -1499,6 +1506,8 @@ def connect(m, *args, **kwargs):
1499
1506
# There are no port members at this point; we're done with this path.
1500
1507
continue
1501
1508
# There are only port members after this point.
1509
+ any_in = any_in or bool (in_kind )
1510
+ any_out = any_out or bool (out_kind )
1502
1511
is_first = True
1503
1512
for (path , member ) in in_kind + out_kind :
1504
1513
member_shape = member .shape
@@ -1574,6 +1583,14 @@ def connect_dimensions(dimensions, *, out_path, in_path):
1574
1583
out_path = (* out_path , index ), in_path = (* in_path , index ))
1575
1584
assert out_member .dimensions == in_member .dimensions
1576
1585
connect_dimensions (out_member .dimensions , out_path = out_path , in_path = in_path )
1586
+
1587
+ # If no connections were made, and there were inputs but no outputs in the
1588
+ # signatures, issue a diagnostic as this is most likely in error.
1589
+ if len (connections ) == 0 and any_in and not any_out :
1590
+ raise ConnectionError (f"Only input to input connections have been made between several "
1591
+ f"interfaces; should one of them have been flipped?" )
1592
+
1593
+
1577
1594
# Now that we know all of the connections are legal, add them to the module. This is done
1578
1595
# instead of returning them because adding them to a non-comb domain would subtly violate
1579
1596
# assumptions that `connect()` is intended to provide.
0 commit comments