Skip to content

Commit 930a379

Browse files
authored
[mypyc] Add is_bool_or_bit_rprimitive (#19406)
Added a wrapper to check if a type is either a bool or bit primitive as these two checks are often done together. The wrapper should help in preventing suboptimal code generation if one forgets to check for the bit primitive in cases when it can be trivially expanded to bool. One such case was in translation of binary ops, which is fixed in this PR. Example code: ``` def f(a: float, b: float, c: float) -> bool: return (a == b) & (a == c) ``` IR before: ``` def f(a, b, c): a, b, c :: float r0, r1 :: bit r2 :: bool r3 :: int r4 :: bool r5, r6 :: int r7 :: object r8, r9 :: bool L0: r0 = a == b r1 = a == c r2 = r0 << 1 r3 = extend r2: builtins.bool to builtins.int r4 = r1 << 1 r5 = extend r4: builtins.bool to builtins.int r6 = CPyTagged_And(r3, r5) dec_ref r3 :: int dec_ref r5 :: int r7 = box(int, r6) r8 = unbox(bool, r7) dec_ref r7 if is_error(r8) goto L2 (error at f:2) else goto L1 L1: return r8 L2: r9 = <error> :: bool return r9 ``` IR after: ``` def f(a, b, c): a, b, c :: float r0, r1 :: bit r2 :: bool L0: r0 = a == b r1 = a == c r2 = r0 & r1 return r2 ```
1 parent 10f95e6 commit 930a379

File tree

5 files changed

+78
-31
lines changed

5 files changed

+78
-31
lines changed

mypyc/codegen/emit.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@
2828
RType,
2929
RUnion,
3030
int_rprimitive,
31-
is_bit_rprimitive,
32-
is_bool_rprimitive,
31+
is_bool_or_bit_rprimitive,
3332
is_bytes_rprimitive,
3433
is_dict_rprimitive,
3534
is_fixed_width_rtype,
@@ -615,8 +614,7 @@ def emit_cast(
615614
or is_range_rprimitive(typ)
616615
or is_float_rprimitive(typ)
617616
or is_int_rprimitive(typ)
618-
or is_bool_rprimitive(typ)
619-
or is_bit_rprimitive(typ)
617+
or is_bool_or_bit_rprimitive(typ)
620618
or is_fixed_width_rtype(typ)
621619
):
622620
if declare_dest:
@@ -638,7 +636,7 @@ def emit_cast(
638636
elif is_int_rprimitive(typ) or is_fixed_width_rtype(typ):
639637
# TODO: Range check for fixed-width types?
640638
prefix = "PyLong"
641-
elif is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
639+
elif is_bool_or_bit_rprimitive(typ):
642640
prefix = "PyBool"
643641
else:
644642
assert False, f"unexpected primitive type: {typ}"
@@ -889,7 +887,7 @@ def emit_unbox(
889887
self.emit_line("else {")
890888
self.emit_line(failure)
891889
self.emit_line("}")
892-
elif is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
890+
elif is_bool_or_bit_rprimitive(typ):
893891
# Whether we are borrowing or not makes no difference.
894892
if declare_dest:
895893
self.emit_line(f"char {dest};")
@@ -1015,7 +1013,7 @@ def emit_box(
10151013
if is_int_rprimitive(typ) or is_short_int_rprimitive(typ):
10161014
# Steal the existing reference if it exists.
10171015
self.emit_line(f"{declaration}{dest} = CPyTagged_StealAsObject({src});")
1018-
elif is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
1016+
elif is_bool_or_bit_rprimitive(typ):
10191017
# N.B: bool is special cased to produce a borrowed value
10201018
# after boxing, so we don't need to increment the refcount
10211019
# when this comes directly from a Box op.

mypyc/ir/ops.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@ class to enable the new behavior. Sometimes adding a new abstract
4242
cstring_rprimitive,
4343
float_rprimitive,
4444
int_rprimitive,
45-
is_bit_rprimitive,
46-
is_bool_rprimitive,
45+
is_bool_or_bit_rprimitive,
4746
is_int_rprimitive,
4847
is_none_rprimitive,
4948
is_pointer_rprimitive,
@@ -1089,11 +1088,7 @@ def __init__(self, src: Value, line: int = -1) -> None:
10891088
self.src = src
10901089
self.type = object_rprimitive
10911090
# When we box None and bool values, we produce a borrowed result
1092-
if (
1093-
is_none_rprimitive(self.src.type)
1094-
or is_bool_rprimitive(self.src.type)
1095-
or is_bit_rprimitive(self.src.type)
1096-
):
1091+
if is_none_rprimitive(self.src.type) or is_bool_or_bit_rprimitive(self.src.type):
10971092
self.is_borrowed = True
10981093

10991094
def sources(self) -> list[Value]:

mypyc/ir/rtypes.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,10 @@ def is_bit_rprimitive(rtype: RType) -> bool:
582582
return isinstance(rtype, RPrimitive) and rtype.name == "bit"
583583

584584

585+
def is_bool_or_bit_rprimitive(rtype: RType) -> bool:
586+
return is_bool_rprimitive(rtype) or is_bit_rprimitive(rtype)
587+
588+
585589
def is_object_rprimitive(rtype: RType) -> bool:
586590
return isinstance(rtype, RPrimitive) and rtype.name == "builtins.object"
587591

mypyc/irbuild/ll_builder.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,7 @@
9393
dict_rprimitive,
9494
float_rprimitive,
9595
int_rprimitive,
96-
is_bit_rprimitive,
97-
is_bool_rprimitive,
96+
is_bool_or_bit_rprimitive,
9897
is_bytes_rprimitive,
9998
is_dict_rprimitive,
10099
is_fixed_width_rtype,
@@ -381,16 +380,12 @@ def coerce(
381380
):
382381
# Equivalent types
383382
return src
384-
elif (is_bool_rprimitive(src_type) or is_bit_rprimitive(src_type)) and is_tagged(
385-
target_type
386-
):
383+
elif is_bool_or_bit_rprimitive(src_type) and is_tagged(target_type):
387384
shifted = self.int_op(
388385
bool_rprimitive, src, Integer(1, bool_rprimitive), IntOp.LEFT_SHIFT
389386
)
390387
return self.add(Extend(shifted, target_type, signed=False))
391-
elif (
392-
is_bool_rprimitive(src_type) or is_bit_rprimitive(src_type)
393-
) and is_fixed_width_rtype(target_type):
388+
elif is_bool_or_bit_rprimitive(src_type) and is_fixed_width_rtype(target_type):
394389
return self.add(Extend(src, target_type, signed=False))
395390
elif isinstance(src, Integer) and is_float_rprimitive(target_type):
396391
if is_tagged(src_type):
@@ -1341,7 +1336,11 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
13411336
return self.compare_strings(lreg, rreg, op, line)
13421337
if is_bytes_rprimitive(ltype) and is_bytes_rprimitive(rtype) and op in ("==", "!="):
13431338
return self.compare_bytes(lreg, rreg, op, line)
1344-
if is_bool_rprimitive(ltype) and is_bool_rprimitive(rtype) and op in BOOL_BINARY_OPS:
1339+
if (
1340+
is_bool_or_bit_rprimitive(ltype)
1341+
and is_bool_or_bit_rprimitive(rtype)
1342+
and op in BOOL_BINARY_OPS
1343+
):
13451344
if op in ComparisonOp.signed_ops:
13461345
return self.bool_comparison_op(lreg, rreg, op, line)
13471346
else:
@@ -1355,7 +1354,7 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
13551354
op_id = int_op_to_id[op]
13561355
else:
13571356
op_id = IntOp.DIV
1358-
if is_bool_rprimitive(rtype) or is_bit_rprimitive(rtype):
1357+
if is_bool_or_bit_rprimitive(rtype):
13591358
rreg = self.coerce(rreg, ltype, line)
13601359
rtype = ltype
13611360
if is_fixed_width_rtype(rtype) or is_tagged(rtype):
@@ -1367,7 +1366,7 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
13671366
elif op in ComparisonOp.signed_ops:
13681367
if is_int_rprimitive(rtype):
13691368
rreg = self.coerce_int_to_fixed_width(rreg, ltype, line)
1370-
elif is_bool_rprimitive(rtype) or is_bit_rprimitive(rtype):
1369+
elif is_bool_or_bit_rprimitive(rtype):
13711370
rreg = self.coerce(rreg, ltype, line)
13721371
op_id = ComparisonOp.signed_ops[op]
13731372
if is_fixed_width_rtype(rreg.type):
@@ -1387,13 +1386,13 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
13871386
)
13881387
if is_tagged(ltype):
13891388
return self.fixed_width_int_op(rtype, lreg, rreg, op_id, line)
1390-
if is_bool_rprimitive(ltype) or is_bit_rprimitive(ltype):
1389+
if is_bool_or_bit_rprimitive(ltype):
13911390
lreg = self.coerce(lreg, rtype, line)
13921391
return self.fixed_width_int_op(rtype, lreg, rreg, op_id, line)
13931392
elif op in ComparisonOp.signed_ops:
13941393
if is_int_rprimitive(ltype):
13951394
lreg = self.coerce_int_to_fixed_width(lreg, rtype, line)
1396-
elif is_bool_rprimitive(ltype) or is_bit_rprimitive(ltype):
1395+
elif is_bool_or_bit_rprimitive(ltype):
13971396
lreg = self.coerce(lreg, rtype, line)
13981397
op_id = ComparisonOp.signed_ops[op]
13991398
if isinstance(lreg, Integer):
@@ -1544,7 +1543,7 @@ def compare_tuples(self, lhs: Value, rhs: Value, op: str, line: int = -1) -> Val
15441543
compare = self.binary_op(lhs_item, rhs_item, op, line)
15451544
# Cast to bool if necessary since most types uses comparison returning a object type
15461545
# See generic_ops.py for more information
1547-
if not (is_bool_rprimitive(compare.type) or is_bit_rprimitive(compare.type)):
1546+
if not is_bool_or_bit_rprimitive(compare.type):
15481547
compare = self.primitive_op(bool_op, [compare], line)
15491548
if i < len(lhs.type.types) - 1:
15501549
branch = Branch(compare, early_stop, check_blocks[i + 1], Branch.BOOL)
@@ -1563,7 +1562,7 @@ def compare_tuples(self, lhs: Value, rhs: Value, op: str, line: int = -1) -> Val
15631562

15641563
def translate_instance_contains(self, inst: Value, item: Value, op: str, line: int) -> Value:
15651564
res = self.gen_method_call(inst, "__contains__", [item], None, line)
1566-
if not is_bool_rprimitive(res.type):
1565+
if not is_bool_or_bit_rprimitive(res.type):
15671566
res = self.primitive_op(bool_op, [res], line)
15681567
if op == "not in":
15691568
res = self.bool_bitwise_op(res, Integer(1, rtype=bool_rprimitive), "^", line)
@@ -1590,7 +1589,7 @@ def unary_not(self, value: Value, line: int) -> Value:
15901589

15911590
def unary_op(self, value: Value, expr_op: str, line: int) -> Value:
15921591
typ = value.type
1593-
if is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
1592+
if is_bool_or_bit_rprimitive(typ):
15941593
if expr_op == "not":
15951594
return self.unary_not(value, line)
15961595
if expr_op == "+":
@@ -1748,7 +1747,7 @@ def bool_value(self, value: Value) -> Value:
17481747
17491748
The result type can be bit_rprimitive or bool_rprimitive.
17501749
"""
1751-
if is_bool_rprimitive(value.type) or is_bit_rprimitive(value.type):
1750+
if is_bool_or_bit_rprimitive(value.type):
17521751
result = value
17531752
elif is_runtime_subtype(value.type, int_rprimitive):
17541753
zero = Integer(0, short_int_rprimitive)

mypyc/test-data/irbuild-bool.test

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,3 +422,54 @@ L0:
422422
r1 = extend r0: builtins.bool to builtins.int
423423
x = r1
424424
return x
425+
426+
[case testBitToBoolPromotion]
427+
def bitand(x: float, y: float, z: float) -> bool:
428+
b = (x == y) & (x == z)
429+
return b
430+
def bitor(x: float, y: float, z: float) -> bool:
431+
b = (x == y) | (x == z)
432+
return b
433+
def bitxor(x: float, y: float, z: float) -> bool:
434+
b = (x == y) ^ (x == z)
435+
return b
436+
def invert(x: float, y: float) -> bool:
437+
return not(x == y)
438+
[out]
439+
def bitand(x, y, z):
440+
x, y, z :: float
441+
r0, r1 :: bit
442+
r2, b :: bool
443+
L0:
444+
r0 = x == y
445+
r1 = x == z
446+
r2 = r0 & r1
447+
b = r2
448+
return b
449+
def bitor(x, y, z):
450+
x, y, z :: float
451+
r0, r1 :: bit
452+
r2, b :: bool
453+
L0:
454+
r0 = x == y
455+
r1 = x == z
456+
r2 = r0 | r1
457+
b = r2
458+
return b
459+
def bitxor(x, y, z):
460+
x, y, z :: float
461+
r0, r1 :: bit
462+
r2, b :: bool
463+
L0:
464+
r0 = x == y
465+
r1 = x == z
466+
r2 = r0 ^ r1
467+
b = r2
468+
return b
469+
def invert(x, y):
470+
x, y :: float
471+
r0, r1 :: bit
472+
L0:
473+
r0 = x == y
474+
r1 = r0 ^ 1
475+
return r1

0 commit comments

Comments
 (0)