Skip to content
This repository was archived by the owner on Jun 10, 2024. It is now read-only.

Commit cca9c6f

Browse files
committed
fix: incorrect lowering of bigint constants
1 parent 72c1d31 commit cca9c6f

File tree

7 files changed

+128
-31
lines changed

7 files changed

+128
-31
lines changed

compiler/codegen/src/passes/ssa_to_mlir/builder/function.rs

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -222,8 +222,14 @@ impl<'m> ModuleBuilder<'m> {
222222
}
223223
ConstantItem::Integer(Integer::Big(i)) => {
224224
let builder = CirBuilder::new(&self.builder);
225-
let ty = builder.get_cir_bigint_type().base();
225+
let ty = builder
226+
.get_cir_box_type(builder.get_cir_bigint_type())
227+
.base();
226228
let op = builder.build_constant(loc, ty, builder.get_bigint_attr(i, ty));
229+
// We need a cast to generic integer type because lowering of boxed types
230+
// like bigint is to a pointer type (the cast below is currently disabled).
231+
//let value = op.get_result(0).base();
232+
//let op = builder.build_cast(loc, value, builder.get_cir_integer_type());
227233
op.get_result(0).base()
228234
}
229235
ConstantItem::Float(f) => {

compiler/mlir/c_src/include/CIR-c/BigIntRef.h

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,9 @@
33
#include <cstdlib>
44

55
#ifdef __cplusplus
6-
#include <llvm/ADT/ArrayRef.h>
6+
namespace llvm {
7+
class StringRef;
8+
}
79

810
namespace mlir {
911
namespace cir {
@@ -14,11 +16,11 @@ enum Sign { SignMinus = 0, SignNoSign, SignPlus };
1416

1517
struct BigIntRef {
1618
Sign sign;
17-
const int32_t *digits;
19+
const char *digits;
1820
size_t len;
1921

2022
#ifdef __cplusplus
21-
llvm::ArrayRef<int32_t> data() const;
23+
llvm::StringRef data() const;
2224
#endif
2325
};
2426

compiler/mlir/c_src/include/CIR/Attributes.td

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -61,8 +61,8 @@ def BigIntRefParameter :
6161
}
6262
}
6363
$_printer << "digits = [";
64-
llvm::interleaveComma($_self.data(), $_printer,
65-
[&](int32_t digit) { $_printer << (uint32_t)digit; });
64+
llvm::interleaveComma($_self.data().bytes(), $_printer,
65+
[&](unsigned char digit) { $_printer << digit; });
6666
$_printer << "]";
6767
}];
6868
let comparator = [{ $_lhs.sign == $_rhs.sign && $_lhs.data() == $_rhs.data() }];
@@ -83,11 +83,11 @@ def BigIntAttr : CIR_Attr<"BigInt"> {
8383
let extraClassDeclaration = [{
8484
using ValueType = ::mlir::cir::BigIntRef;
8585
Sign getSign() const { return getValue().sign; }
86-
::llvm::ArrayRef<int32_t> getDigits() const { return getValue().data(); }
86+
::llvm::StringRef getDigits() const { return getValue().data(); }
8787

8888
}];
8989
let skipDefaultBuilders = 1;
90-
let typeBuilder = "CIRBigIntType::get($_type.getContext())";
90+
let typeBuilder = "CIRBoxType::get(CIRBigIntType::get($_type.getContext()))";
9191
}
9292

9393
def BinaryEntrySpecifierParameter :

compiler/mlir/c_src/lib/CIR/Attributes.cpp

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,7 @@ template <> struct FieldParser<AtomRef> {
5151
/// BigIntRef
5252
//===----------------------------------------------------------------------===//
5353

54-
llvm::ArrayRef<int32_t> BigIntRef::data() const { return {digits, len}; }
54+
llvm::StringRef BigIntRef::data() const { return {digits, len}; }
5555

5656
llvm::hash_code mlir::cir::hash_value(const BigIntRef &bigint) {
5757
auto data = bigint.data();
@@ -79,9 +79,9 @@ template <> struct FieldParser<BigIntRef> {
7979
if (kw1 != "digits" || parser.parseEqual())
8080
return failure();
8181

82-
SmallVector<int32_t, 4> digits;
82+
SmallVector<int8_t, 4> digits;
8383
auto parseElt = [&] {
84-
int32_t digit;
84+
int8_t digit;
8585
if (parser.parseInteger(digit))
8686
return failure();
8787
digits.push_back(digit);
@@ -93,9 +93,9 @@ template <> struct FieldParser<BigIntRef> {
9393
return failure();
9494

9595
size_t len = digits.size();
96-
int32_t *data = nullptr;
96+
char *data = nullptr;
9797
if (len > 0) {
98-
data = static_cast<int32_t *>(aligned_alloc(16, digits.size_in_bytes()));
98+
data = static_cast<char *>(aligned_alloc(16, digits.size_in_bytes()));
9999
std::memcpy(data, digits.data(), len);
100100
}
101101
return BigIntRef{sign, data, len};

compiler/mlir/c_src/lib/CIR/ConvertCIRToLLVMPass.cpp

Lines changed: 14 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -679,19 +679,19 @@ class CIRConversionPattern : public ConversionPattern {
679679
}
680680

681681
Value createBigIntConstant(OpBuilder &builder, Location loc, Sign sign,
682-
ArrayRef<int32_t> digits, ModuleOp &module) const {
682+
StringRef digits, ModuleOp &module) const {
683683
llvm::SHA1 hasher;
684684
hasher.update((unsigned)sign);
685-
for (auto digit : digits)
686-
hasher.update(digit);
685+
hasher.update(digits);
687686
auto hash = llvm::toHex(hasher.result(), true);
688687
auto globalName = std::string("bigint_") + hash;
689688

689+
auto i8Ty = builder.getI8Type();
690690
auto i32Ty = builder.getI32Type();
691691
auto isizeTy = getIsizeType();
692692
auto termTy = getTermType();
693-
auto digitsTy = LLVM::LLVMArrayType::get(i32Ty, digits.size());
694-
auto emptyArrayTy = LLVM::LLVMArrayType::get(i32Ty, 0);
693+
auto digitsTy = LLVM::LLVMArrayType::get(i8Ty, digits.size());
694+
auto emptyArrayTy = LLVM::LLVMArrayType::get(i8Ty, 0);
695695
auto dataTy = LLVM::LLVMStructType::getLiteral(builder.getContext(),
696696
{i32Ty, digitsTy});
697697
auto genericDataTy =
@@ -715,13 +715,13 @@ class CIRConversionPattern : public ConversionPattern {
715715

716716
Value dataSign = createI32Constant(builder, loc, (unsigned)sign);
717717
Value dataRaw = builder.create<LLVM::ConstantOp>(
718-
loc, digitsTy, builder.getI32ArrayAttr(digits));
718+
loc, digitsTy, builder.getStringAttr(digits.str()));
719719

720720
Value data = builder.create<LLVM::UndefOp>(loc, dataTy);
721721
data = builder.create<LLVM::InsertValueOp>(loc, data, dataSign,
722722
builder.getI64ArrayAttr(0));
723723
data = builder.create<LLVM::InsertValueOp>(loc, data, dataRaw,
724-
builder.getI64ArrayAttr(2));
724+
builder.getI64ArrayAttr(1));
725725
builder.create<LLVM::ReturnOp>(loc, data);
726726
}
727727

@@ -1218,11 +1218,6 @@ struct ConstantOpLowering : public ConvertCIROpToLLVMPattern<cir::ConstantOp> {
12181218
assert(NANBOX_INFINITY != 0);
12191219
return createTermConstant(rewriter, loc, NANBOX_INFINITY);
12201220
})
1221-
.Case<CIRBigIntType>([&](CIRBigIntType) {
1222-
auto bigIntAttr = attr.cast<BigIntAttr>();
1223-
return createBigIntConstant(rewriter, loc, bigIntAttr.getSign(),
1224-
bigIntAttr.getDigits(), module);
1225-
})
12261221
.Case<CIRIntegerType>([&](CIRIntegerType) {
12271222
return createIntegerConstant(rewriter, loc,
12281223
attr.cast<IsizeAttr>().getInt());
@@ -1254,6 +1249,12 @@ struct ConstantOpLowering : public ConvertCIROpToLLVMPattern<cir::ConstantOp> {
12541249
return createBinaryDataConstant(rewriter, loc, str, isUtf8,
12551250
module);
12561251
})
1252+
.Case<CIRBigIntType>([&](CIRBigIntType) {
1253+
auto bigIntAttr = attr.cast<BigIntAttr>();
1254+
return createBigIntConstant(rewriter, loc,
1255+
bigIntAttr.getSign(),
1256+
bigIntAttr.getDigits(), module);
1257+
})
12571258
.Default([](Type) { return nullptr; });
12581259
})
12591260
.Default([](Type) { return nullptr; });
@@ -2927,7 +2928,7 @@ struct BinaryPushOpLowering
29272928
}
29282929

29292930
auto callOp = rewriter.create<LLVM::CallOp>(
2930-
loc, TypeRange({resultTy}), "__lumen_bs_match",
2931+
loc, TypeRange({resultTy}), "__lumen_bs_push",
29312932
ValueRange({bin, specRaw, size, adaptor.value()}));
29322933
Value callResult = callOp->getResult(0);
29332934
Value isErrWide = rewriter.create<LLVM::ExtractValueOp>(

compiler/mlir/src/dialect/cir.rs

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -135,7 +135,7 @@ impl Into<Sign> for BigIntSign {
135135
#[repr(C)]
136136
struct BigIntRaw {
137137
sign: BigIntSign,
138-
digits: *const u32,
138+
digits: *const u8,
139139
num_digits: usize,
140140
}
141141

@@ -146,7 +146,7 @@ pub struct BigIntAttr(AttributeBase);
146146
impl BigIntAttr {
147147
#[inline]
148148
pub fn get(value: &BigInt, ty: TypeBase) -> Self {
149-
let (sign, digits) = value.to_u32_digits();
149+
let (sign, digits) = value.to_bytes_be();
150150
let raw = BigIntRaw {
151151
sign: sign.into(),
152152
digits: digits.as_ptr(),
@@ -159,7 +159,7 @@ impl BigIntAttr {
159159
pub fn value(&self) -> BigInt {
160160
let raw = self.raw();
161161
let digits = unsafe { core::slice::from_raw_parts(raw.digits, raw.num_digits) };
162-
BigInt::from_slice(raw.sign.into(), digits)
162+
BigInt::from_bytes_be(raw.sign.into(), digits)
163163
}
164164

165165
fn raw(&self) -> BigIntRaw {

runtimes/tiny/src/intrinsic/mod.rs

Lines changed: 90 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,7 @@ impl Into<Sign> for SignRaw {
5454
#[repr(C)]
5555
pub struct BigIntRef {
5656
sign: SignRaw,
57-
digits: [u32],
57+
digits: [u8],
5858
}
5959

6060
/// Allocates a new BigInteger from constant data produced by the compiler
@@ -65,7 +65,7 @@ pub extern "C" fn bigint_from_digits(raw: &BigIntRef) -> OpaqueTerm {
6565
let arc_proc = scheduler.current_process();
6666
let proc = arc_proc.deref();
6767
GcBox::new_in(
68-
BigInteger(BigInt::from_slice(raw.sign.into(), &raw.digits[..])),
68+
BigInteger(BigInt::from_bytes_be(raw.sign.into(), &raw.digits[..])),
6969
proc,
7070
)
7171
.unwrap()
@@ -268,6 +268,87 @@ pub extern "C-unwind" fn bs_init() -> Result<NonNull<BitVec>, ()> {
268268
Ok(unsafe { NonNull::new_unchecked(Box::into_raw(buffer)) })
269269
}
270270

271+
#[allow(improper_ctypes_definitions)]
272+
#[export_name = "__lumen_bs_push"]
273+
pub extern "C-unwind" fn bs_push(
274+
mut bin: NonNull<BitVec>,
275+
spec: BinaryEntrySpecifier,
276+
value: OpaqueTerm,
277+
size: OpaqueTerm,
278+
) -> Result<NonNull<BitVec>, NonNull<ErlangException>> {
279+
let buffer = unsafe { bin.as_mut() };
280+
match spec {
281+
BinaryEntrySpecifier::Integer {
282+
signed,
283+
unit,
284+
endianness,
285+
} => match (value.into(), size.into()) {
286+
(Term::Int(i), Term::Int(size)) => {
287+
if signed {
288+
buffer.push_ap_number(i, (unit as usize) * (size as usize), endianness);
289+
} else {
290+
buffer.push_ap_number(i as u64, (unit as usize) * (size as usize), endianness);
291+
}
292+
Ok(bin)
293+
}
294+
(Term::BigInt(i), Term::Int(size)) => {
295+
buffer.push_ap_bigint(
296+
i.deref(),
297+
(unit as usize) * (size as usize),
298+
signed,
299+
endianness,
300+
);
301+
Ok(bin)
302+
}
303+
_ => Err(badarg(Trace::capture())),
304+
},
305+
BinaryEntrySpecifier::Float { unit, endianness } => match (value.into(), size.into()) {
306+
(Term::Float(f), Term::Int(size)) => match (unit as usize) * (size as usize) {
307+
64 => {
308+
buffer.push_number(f.as_f64(), endianness);
309+
Ok(bin)
310+
}
311+
_ => todo!("bs.push float"),
312+
},
313+
_ => Err(badarg(Trace::capture())),
314+
},
315+
BinaryEntrySpecifier::Binary { unit } => match (value.into(), size.into()) {
316+
(Term::ConstantBinary(bin), Term::Int(size)) => {
317+
let bitsize = (unit as usize) * (size as usize);
318+
todo!()
319+
}
320+
_ => Err(badarg(Trace::capture())),
321+
},
322+
BinaryEntrySpecifier::Utf8 => match (value.into(), size.into()) {
323+
(Term::Int(i), Term::Int(size)) => {
324+
let Ok(codepoint) = i.try_into() else { return Err(badarg(Trace::capture())); };
325+
let Some(c) = char::from_u32(codepoint) else { return Err(badarg(Trace::capture())); };
326+
buffer.push_utf8(c);
327+
Ok(bin)
328+
}
329+
_ => Err(badarg(Trace::capture())),
330+
},
331+
BinaryEntrySpecifier::Utf16 { endianness } => match (value.into(), size.into()) {
332+
(Term::Int(i), Term::Int(size)) => {
333+
let Ok(codepoint) = i.try_into() else { return Err(badarg(Trace::capture())); };
334+
let Some(c) = char::from_u32(codepoint) else { return Err(badarg(Trace::capture())); };
335+
buffer.push_utf16(c, endianness);
336+
Ok(bin)
337+
}
338+
_ => Err(badarg(Trace::capture())),
339+
},
340+
BinaryEntrySpecifier::Utf32 { endianness } => match (value.into(), size.into()) {
341+
(Term::Int(i), Term::Int(size)) => {
342+
let Ok(codepoint) = i.try_into() else { return Err(badarg(Trace::capture())); };
343+
let Some(c) = char::from_u32(codepoint) else { return Err(badarg(Trace::capture())); };
344+
buffer.push_utf32(c, endianness);
345+
Ok(bin)
346+
}
347+
_ => Err(badarg(Trace::capture())),
348+
},
349+
}
350+
}
351+
271352
#[allow(improper_ctypes_definitions)]
272353
#[export_name = "__lumen_bs_finish"]
273354
pub extern "C-unwind" fn bs_finish(buffer: NonNull<BitVec>) -> Result<OpaqueTerm, ()> {
@@ -415,3 +496,10 @@ pub extern "C-unwind" fn bs_match(
415496
pub unsafe extern "C-unwind" fn bs_test_tail(ctx: NonNull<MatchContext>, size: usize) -> bool {
416497
ctx.as_ref().bits_remaining() == size
417498
}
499+
500+
pub(self) fn badarg(trace: Arc<Trace>) -> NonNull<ErlangException> {
501+
crate::erlang::raise2(atoms::Badarg.into(), unsafe {
502+
NonNull::new_unchecked(Trace::into_raw(trace))
503+
})
504+
.unwrap_err()
505+
}

0 commit comments

Comments
 (0)