Skip to content

Use more precise Wasm types #1907

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 70 additions & 2 deletions compiler/lib-wasm/code_generation.ml
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,68 @@ let heap_type_sub (ty : W.heap_type) (ty' : W.heap_type) st =
(* I31, struct, array and none have no other subtype *)
| _, (I31 | Type _ | Struct | Array | None_) -> false, st

let rec type_index_lub ty ty' st =
(* Find the LUB efficiently by taking advantage of the fact that
types are defined after their supertypes, making their variables
compare greater. *)
let c = Var.compare ty ty' in
if c > 0
then type_index_lub ty' ty st
else if c = 0
then Some ty
else
let type_field = Var.Hashtbl.find st.context.types ty' in
match type_field.supertype with
| None -> None
| Some ty'' ->
assert (Var.compare ty'' ty' < 0);
type_index_lub ty ty'' st

let heap_type_lub (ty : W.heap_type) (ty' : W.heap_type) =
match ty, ty' with
| (Func | Extern), _ | _, (Func | Extern) -> assert false
| None_, _ -> return ty'
| _, None_ | Struct, Struct | Array, Array -> return ty
| Any, _ | _, Any -> return W.Any
| Eq, _
| _, Eq
| (Struct | Array | Type _), I31
| I31, (Struct | Array | Type _)
| Struct, Array
| Array, Struct -> return (Eq : W.heap_type)
| Struct, Type t | Type t, Struct -> (
fun st ->
let type_field = Var.Hashtbl.find st.context.types t in
match type_field.typ with
| Struct _ -> W.Struct, st
| Array _ | Func _ -> W.Eq, st)
| Array, Type t | Type t, Array -> (
fun st ->
let type_field = Var.Hashtbl.find st.context.types t in
match type_field.typ with
| Array _ -> W.Struct, st
| Struct _ | Func _ -> W.Eq, st)
| Type t, Type t' -> (
let* r = fun st -> type_index_lub t t' st, st in
match r with
| Some t'' -> return (Type t'' : W.heap_type)
| None -> (
fun st ->
let type_field = Var.Hashtbl.find st.context.types t in
let type_field' = Var.Hashtbl.find st.context.types t' in
match type_field.typ, type_field'.typ with
| Struct _, Struct _ -> (Struct : W.heap_type), st
| Array _, Array _ -> W.Array, st
| (Array _ | Struct _ | Func _), (Array _ | Struct _ | Func _) -> W.Eq, st))
| I31, I31 -> return W.I31

let value_type_lub (ty : W.value_type) (ty' : W.value_type) =
match ty, ty' with
| Ref { nullable; typ }, Ref { nullable = nullable'; typ = typ' } ->
let* typ = heap_type_lub typ typ' in
return (W.Ref { nullable = nullable || nullable'; typ })
| _ -> assert false

let register_global name ?exported_name ?(constant = false) typ init st =
st.context.other_fields <-
W.Global { name; exported_name; typ; init } :: st.context.other_fields;
Expand Down Expand Up @@ -705,7 +767,7 @@ let init_code context = instrs context.init_code

let function_body ~context ~param_names ~body =
let st = { var_count = 0; vars = Var.Map.empty; instrs = []; context } in
let (), st = body st in
let res, st = body st in
let local_count, body = st.var_count, List.rev st.instrs in
let local_types = Array.make local_count (Var.fresh (), None) in
List.iteri ~f:(fun i x -> local_types.(i) <- x, None) param_names;
Expand All @@ -723,4 +785,10 @@ let function_body ~context ~param_names ~body =
|> (fun a -> Array.sub a ~pos:param_count ~len:(Array.length a - param_count))
|> Array.to_list
in
locals, body
locals, res, body

let eval ~context e =
let st = { var_count = 0; vars = Var.Map.empty; instrs = []; context } in
let r, st = e st in
assert (st.var_count = 0 && List.is_empty st.instrs);
r
10 changes: 8 additions & 2 deletions compiler/lib-wasm/code_generation.mli
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ val register_type : string -> (unit -> type_def t) -> Wasm_ast.var t

val heap_type_sub : Wasm_ast.heap_type -> Wasm_ast.heap_type -> bool t

val value_type_lub : Wasm_ast.value_type -> Wasm_ast.value_type -> Wasm_ast.value_type t

val register_import :
?import_module:string -> name:string -> Wasm_ast.import_desc -> Wasm_ast.var t

Expand Down Expand Up @@ -195,13 +197,17 @@ val need_dummy_fun : cps:bool -> arity:int -> Code.Var.t t
val function_body :
context:context
-> param_names:Code.Var.t list
-> body:unit t
-> (Wasm_ast.var * Wasm_ast.value_type) list * Wasm_ast.instruction list
-> body:'a t
-> (Wasm_ast.var * Wasm_ast.value_type) list * 'a * Wasm_ast.instruction list

val variable_type : Code.Var.t -> Wasm_ast.value_type option t

val expression_type : Wasm_ast.expression -> Wasm_ast.value_type option t

val array_placeholder : Code.Var.t -> expression

val default_value :
Wasm_ast.value_type
-> (Wasm_ast.expression * Wasm_ast.value_type * Wasm_ast.ref_type option) t

val eval : context:context -> 'a t -> 'a
26 changes: 14 additions & 12 deletions compiler/lib-wasm/curry.ml
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,11 @@ module Make (Target : Target_sig.S) = struct
loop m [] f None
in
let param_names = args @ [ f ] in
let locals, body = function_body ~context ~param_names ~body in
let locals, _, body = function_body ~context ~param_names ~body in
W.Function
{ name
; exported_name = None
; typ = None
; typ = Some (eval ~context (Type.function_type ~cps:false 1))
; signature = Type.func_type 1
; param_names
; locals
Expand Down Expand Up @@ -130,11 +130,11 @@ module Make (Target : Target_sig.S) = struct
push (Closure.curry_allocate ~cps:false ~arity m ~f:name' ~closure:f ~arg:x)
in
let param_names = [ x; f ] in
let locals, body = function_body ~context ~param_names ~body in
let locals, _, body = function_body ~context ~param_names ~body in
W.Function
{ name
; exported_name = None
; typ = None
; typ = Some (eval ~context (Type.function_type ~cps:false 1))
; signature = Type.func_type 1
; param_names
; locals
Expand Down Expand Up @@ -181,11 +181,11 @@ module Make (Target : Target_sig.S) = struct
loop m [] f None
in
let param_names = args @ [ f ] in
let locals, body = function_body ~context ~param_names ~body in
let locals, _, body = function_body ~context ~param_names ~body in
W.Function
{ name
; exported_name = None
; typ = None
; typ = Some (eval ~context (Type.function_type ~cps:true 1))
; signature = Type.func_type 2
; param_names
; locals
Expand Down Expand Up @@ -220,11 +220,11 @@ module Make (Target : Target_sig.S) = struct
instr (W.Return (Some c))
in
let param_names = [ x; cont; f ] in
let locals, body = function_body ~context ~param_names ~body in
let locals, _, body = function_body ~context ~param_names ~body in
W.Function
{ name
; exported_name = None
; typ = None
; typ = Some (eval ~context (Type.function_type ~cps:true 1))
; signature = Type.func_type 2
; param_names
; locals
Expand Down Expand Up @@ -264,7 +264,7 @@ module Make (Target : Target_sig.S) = struct
build_applies (load f) l)
in
let param_names = l @ [ f ] in
let locals, body = function_body ~context ~param_names ~body in
let locals, _, body = function_body ~context ~param_names ~body in
W.Function
{ name
; exported_name = None
Expand Down Expand Up @@ -312,7 +312,7 @@ module Make (Target : Target_sig.S) = struct
push (call ~cps:true ~arity:2 (load f) [ x; iterate ]))
in
let param_names = l @ [ f ] in
let locals, body = function_body ~context ~param_names ~body in
let locals, _, body = function_body ~context ~param_names ~body in
W.Function
{ name
; exported_name = None
Expand Down Expand Up @@ -347,11 +347,13 @@ module Make (Target : Target_sig.S) = struct
instr (W.Return (Some e))
in
let param_names = l @ [ f ] in
let locals, body = function_body ~context ~param_names ~body in
let locals, _, body = function_body ~context ~param_names ~body in
W.Function
{ name
; exported_name = None
; typ = None
; typ =
Some
(eval ~context (Type.function_type ~cps (if cps then arity - 1 else arity)))
; signature = Type.func_type arity
; param_names
; locals
Expand Down
62 changes: 47 additions & 15 deletions compiler/lib-wasm/gc_target.ml
Original file line number Diff line number Diff line change
Expand Up @@ -205,12 +205,35 @@ module Type = struct
let primitive_type n =
{ W.params = List.init ~len:n ~f:(fun _ -> value); result = [ value ] }

let func_type n = primitive_type (n + 1)

let function_type ~cps n =
let n = if cps then n + 1 else n in
register_type (Printf.sprintf "function_%d" n) (fun () ->
return { supertype = None; final = true; typ = W.Func (func_type n) })
let func_type ?(ret = value) n =
{ W.params = List.init ~len:(n + 1) ~f:(fun _ -> value); result = [ ret ] }

let rec function_type ~cps ?ret n =
let n' = if cps then n + 1 else n in
let ret_str =
match ret with
| None -> ""
| Some (W.Ref { nullable = false; typ }) -> (
match typ with
| Eq -> "_eq" (*ZZZ remove ret in that case*)
| I31 -> "_i31"
| Struct -> "_struct"
| Array -> "_array"
| None_ -> "_none"
| Type v -> (
match Code.Var.get_name v with
| None -> assert false
| Some name -> "_" ^ name)
| _ -> assert false)
| _ -> assert false
in
register_type (Printf.sprintf "function_%d%s" n' ret_str) (fun () ->
match ret with
| None -> return { supertype = None; final = false; typ = W.Func (func_type n') }
| Some ret ->
let* super = function_type ~cps n in
return
{ supertype = Some super; final = false; typ = W.Func (func_type ~ret n') })

let closure_common_fields ~cps =
let* fun_ty = function_type ~cps 1 in
Expand Down Expand Up @@ -606,6 +629,14 @@ module Value = struct
let int_asr = Arith.( asr )
end

let store_in_global ?(name = "const") c =
let name = Code.Var.fresh_n name in
let* typ = expression_type c in
let* () =
register_global name { mut = false; typ = Option.value ~default:Type.value typ } c
in
return (W.GlobalGet name)

module Memory = struct
let wasm_cast ty e =
let* e = e in
Expand Down Expand Up @@ -862,7 +893,9 @@ module Memory = struct
in
let* ty = Type.int32_type in
let* e = e in
return (W.StructNew (ty, [ GlobalGet int32_ops; e ]))
let e' = W.StructNew (ty, [ GlobalGet int32_ops; e ]) in
let* b = is_small_constant e in
if b then store_in_global e' else return e'

let box_int32 e = make_int32 ~kind:`Int32 e

Expand All @@ -880,7 +913,9 @@ module Memory = struct
in
let* ty = Type.int64_type in
let* e = e in
return (W.StructNew (ty, [ GlobalGet int64_ops; e ]))
let e' = W.StructNew (ty, [ GlobalGet int64_ops; e ]) in
let* b = is_small_constant e in
if b then store_in_global e' else return e'

let box_int64 e = make_int64 e

Expand All @@ -900,11 +935,6 @@ module Constant = struct
strings are encoded as a sequence of bytes in the wasm module. *)
let string_length_threshold = 64

let store_in_global ?(name = "const") c =
let name = Code.Var.fresh_n name in
let* () = register_global name { mut = false; typ = Type.value } c in
return (W.GlobalGet name)

let byte_string s =
let b = Buffer.create (String.length s) in
String.iter s ~f:(function
Expand Down Expand Up @@ -1037,13 +1067,15 @@ module Constant = struct
if b then return c else store_in_global c
| Const_named name -> store_in_global ~name c
| Mutated ->
let* typ = Type.string_type in
let name = Code.Var.fresh_n "const" in
let* placeholder = array_placeholder typ in
let* () =
register_global
~constant:true
name
{ mut = true; typ = Type.value }
(W.RefI31 (Const (I32 0l)))
{ mut = true; typ = Ref { nullable = false; typ = Type typ } }
placeholder
in
let* () = register_init_code (instr (W.GlobalSet (name, c))) in
return (W.GlobalGet name))
Expand Down
Loading
Loading