Skip to content

[mypyc] Add is_bool_or_bit_rprimitive #19406

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions mypyc/codegen/emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@
RType,
RUnion,
int_rprimitive,
is_bit_rprimitive,
is_bool_rprimitive,
is_bool_or_bit_rprimitive,
is_bytes_rprimitive,
is_dict_rprimitive,
is_fixed_width_rtype,
Expand Down Expand Up @@ -615,8 +614,7 @@ def emit_cast(
or is_range_rprimitive(typ)
or is_float_rprimitive(typ)
or is_int_rprimitive(typ)
or is_bool_rprimitive(typ)
or is_bit_rprimitive(typ)
or is_bool_or_bit_rprimitive(typ)
or is_fixed_width_rtype(typ)
):
if declare_dest:
Expand All @@ -638,7 +636,7 @@ def emit_cast(
elif is_int_rprimitive(typ) or is_fixed_width_rtype(typ):
# TODO: Range check for fixed-width types?
prefix = "PyLong"
elif is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
elif is_bool_or_bit_rprimitive(typ):
prefix = "PyBool"
else:
assert False, f"unexpected primitive type: {typ}"
Expand Down Expand Up @@ -889,7 +887,7 @@ def emit_unbox(
self.emit_line("else {")
self.emit_line(failure)
self.emit_line("}")
elif is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
elif is_bool_or_bit_rprimitive(typ):
# Whether we are borrowing or not makes no difference.
if declare_dest:
self.emit_line(f"char {dest};")
Expand Down Expand Up @@ -1015,7 +1013,7 @@ def emit_box(
if is_int_rprimitive(typ) or is_short_int_rprimitive(typ):
# Steal the existing reference if it exists.
self.emit_line(f"{declaration}{dest} = CPyTagged_StealAsObject({src});")
elif is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
elif is_bool_or_bit_rprimitive(typ):
# N.B: bool is special cased to produce a borrowed value
# after boxing, so we don't need to increment the refcount
# when this comes directly from a Box op.
Expand Down
9 changes: 2 additions & 7 deletions mypyc/ir/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ class to enable the new behavior. Sometimes adding a new abstract
cstring_rprimitive,
float_rprimitive,
int_rprimitive,
is_bit_rprimitive,
is_bool_rprimitive,
is_bool_or_bit_rprimitive,
is_int_rprimitive,
is_none_rprimitive,
is_pointer_rprimitive,
Expand Down Expand Up @@ -1089,11 +1088,7 @@ def __init__(self, src: Value, line: int = -1) -> None:
self.src = src
self.type = object_rprimitive
# When we box None and bool values, we produce a borrowed result
if (
is_none_rprimitive(self.src.type)
or is_bool_rprimitive(self.src.type)
or is_bit_rprimitive(self.src.type)
):
if is_none_rprimitive(self.src.type) or is_bool_or_bit_rprimitive(self.src.type):
self.is_borrowed = True

def sources(self) -> list[Value]:
Expand Down
4 changes: 4 additions & 0 deletions mypyc/ir/rtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,10 @@ def is_bit_rprimitive(rtype: RType) -> bool:
return isinstance(rtype, RPrimitive) and rtype.name == "bit"


def is_bool_or_bit_rprimitive(rtype: RType) -> bool:
return is_bool_rprimitive(rtype) or is_bit_rprimitive(rtype)


def is_object_rprimitive(rtype: RType) -> bool:
return isinstance(rtype, RPrimitive) and rtype.name == "builtins.object"

Expand Down
33 changes: 16 additions & 17 deletions mypyc/irbuild/ll_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@
dict_rprimitive,
float_rprimitive,
int_rprimitive,
is_bit_rprimitive,
is_bool_rprimitive,
is_bool_or_bit_rprimitive,
is_bytes_rprimitive,
is_dict_rprimitive,
is_fixed_width_rtype,
Expand Down Expand Up @@ -376,16 +375,12 @@ def coerce(
):
# Equivalent types
return src
elif (is_bool_rprimitive(src_type) or is_bit_rprimitive(src_type)) and is_tagged(
target_type
):
elif is_bool_or_bit_rprimitive(src_type) and is_tagged(target_type):
shifted = self.int_op(
bool_rprimitive, src, Integer(1, bool_rprimitive), IntOp.LEFT_SHIFT
)
return self.add(Extend(shifted, target_type, signed=False))
elif (
is_bool_rprimitive(src_type) or is_bit_rprimitive(src_type)
) and is_fixed_width_rtype(target_type):
elif is_bool_or_bit_rprimitive(src_type) and is_fixed_width_rtype(target_type):
return self.add(Extend(src, target_type, signed=False))
elif isinstance(src, Integer) and is_float_rprimitive(target_type):
if is_tagged(src_type):
Expand Down Expand Up @@ -1336,7 +1331,11 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
return self.compare_strings(lreg, rreg, op, line)
if is_bytes_rprimitive(ltype) and is_bytes_rprimitive(rtype) and op in ("==", "!="):
return self.compare_bytes(lreg, rreg, op, line)
if is_bool_rprimitive(ltype) and is_bool_rprimitive(rtype) and op in BOOL_BINARY_OPS:
if (
is_bool_or_bit_rprimitive(ltype)
and is_bool_or_bit_rprimitive(rtype)
and op in BOOL_BINARY_OPS
):
if op in ComparisonOp.signed_ops:
return self.bool_comparison_op(lreg, rreg, op, line)
else:
Expand All @@ -1350,7 +1349,7 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
op_id = int_op_to_id[op]
else:
op_id = IntOp.DIV
if is_bool_rprimitive(rtype) or is_bit_rprimitive(rtype):
if is_bool_or_bit_rprimitive(rtype):
rreg = self.coerce(rreg, ltype, line)
rtype = ltype
if is_fixed_width_rtype(rtype) or is_tagged(rtype):
Expand All @@ -1362,7 +1361,7 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
elif op in ComparisonOp.signed_ops:
if is_int_rprimitive(rtype):
rreg = self.coerce_int_to_fixed_width(rreg, ltype, line)
elif is_bool_rprimitive(rtype) or is_bit_rprimitive(rtype):
elif is_bool_or_bit_rprimitive(rtype):
rreg = self.coerce(rreg, ltype, line)
op_id = ComparisonOp.signed_ops[op]
if is_fixed_width_rtype(rreg.type):
Expand All @@ -1382,13 +1381,13 @@ def binary_op(self, lreg: Value, rreg: Value, op: str, line: int) -> Value:
)
if is_tagged(ltype):
return self.fixed_width_int_op(rtype, lreg, rreg, op_id, line)
if is_bool_rprimitive(ltype) or is_bit_rprimitive(ltype):
if is_bool_or_bit_rprimitive(ltype):
lreg = self.coerce(lreg, rtype, line)
return self.fixed_width_int_op(rtype, lreg, rreg, op_id, line)
elif op in ComparisonOp.signed_ops:
if is_int_rprimitive(ltype):
lreg = self.coerce_int_to_fixed_width(lreg, rtype, line)
elif is_bool_rprimitive(ltype) or is_bit_rprimitive(ltype):
elif is_bool_or_bit_rprimitive(ltype):
lreg = self.coerce(lreg, rtype, line)
op_id = ComparisonOp.signed_ops[op]
if isinstance(lreg, Integer):
Expand Down Expand Up @@ -1534,7 +1533,7 @@ def compare_tuples(self, lhs: Value, rhs: Value, op: str, line: int = -1) -> Val
compare = self.binary_op(lhs_item, rhs_item, op, line)
# Cast to bool if necessary since most types uses comparison returning a object type
# See generic_ops.py for more information
if not (is_bool_rprimitive(compare.type) or is_bit_rprimitive(compare.type)):
if not is_bool_or_bit_rprimitive(compare.type):
compare = self.primitive_op(bool_op, [compare], line)
if i < len(lhs.type.types) - 1:
branch = Branch(compare, early_stop, check_blocks[i + 1], Branch.BOOL)
Expand All @@ -1553,7 +1552,7 @@ def compare_tuples(self, lhs: Value, rhs: Value, op: str, line: int = -1) -> Val

def translate_instance_contains(self, inst: Value, item: Value, op: str, line: int) -> Value:
res = self.gen_method_call(inst, "__contains__", [item], None, line)
if not is_bool_rprimitive(res.type):
if not is_bool_or_bit_rprimitive(res.type):
res = self.primitive_op(bool_op, [res], line)
if op == "not in":
res = self.bool_bitwise_op(res, Integer(1, rtype=bool_rprimitive), "^", line)
Expand All @@ -1580,7 +1579,7 @@ def unary_not(self, value: Value, line: int) -> Value:

def unary_op(self, value: Value, expr_op: str, line: int) -> Value:
typ = value.type
if is_bool_rprimitive(typ) or is_bit_rprimitive(typ):
if is_bool_or_bit_rprimitive(typ):
if expr_op == "not":
return self.unary_not(value, line)
if expr_op == "+":
Expand Down Expand Up @@ -1738,7 +1737,7 @@ def bool_value(self, value: Value) -> Value:

The result type can be bit_rprimitive or bool_rprimitive.
"""
if is_bool_rprimitive(value.type) or is_bit_rprimitive(value.type):
if is_bool_or_bit_rprimitive(value.type):
result = value
elif is_runtime_subtype(value.type, int_rprimitive):
zero = Integer(0, short_int_rprimitive)
Expand Down
51 changes: 51 additions & 0 deletions mypyc/test-data/irbuild-bool.test
Original file line number Diff line number Diff line change
Expand Up @@ -422,3 +422,54 @@ L0:
r1 = extend r0: builtins.bool to builtins.int
x = r1
return x

[case testBitToBoolPromotion]
def bitand(x: float, y: float, z: float) -> bool:
b = (x == y) & (x == z)
return b
def bitor(x: float, y: float, z: float) -> bool:
b = (x == y) | (x == z)
return b
def bitxor(x: float, y: float, z: float) -> bool:
b = (x == y) ^ (x == z)
return b
def invert(x: float, y: float) -> bool:
return not(x == y)
[out]
def bitand(x, y, z):
x, y, z :: float
r0, r1 :: bit
r2, b :: bool
L0:
r0 = x == y
r1 = x == z
r2 = r0 & r1
b = r2
return b
def bitor(x, y, z):
x, y, z :: float
r0, r1 :: bit
r2, b :: bool
L0:
r0 = x == y
r1 = x == z
r2 = r0 | r1
b = r2
return b
def bitxor(x, y, z):
x, y, z :: float
r0, r1 :: bit
r2, b :: bool
L0:
r0 = x == y
r1 = x == z
r2 = r0 ^ r1
b = r2
return b
def invert(x, y):
x, y :: float
r0, r1 :: bit
L0:
r0 = x == y
r1 = r0 ^ 1
return r1