blob: 52e6f5d3bcb6c9d7c96ad0e639a21034836019c1 [file] [log] [blame] [edit]
(* Types *)
type t = string
type bits = string
type ('i8x16, 'i16x8, 'i32x4, 'i64x2, 'f32x4, 'f64x2) laneop =
| I8x16 of 'i8x16 | I16x8 of 'i16x8 | I32x4 of 'i32x4 | I64x2 of 'i64x2
| F32x4 of 'f32x4 | F64x2 of 'f64x2
type shape = (unit, unit, unit, unit, unit, unit) laneop
(* Basics *)
let bitwidth = 128
let bytewidth = bitwidth / 8
let zero = String.make bytewidth '\x00'
let of_bits x = x
let to_bits x = x
let num_lanes shape =
match shape with
| I8x16 _ -> 16
| I16x8 _ -> 8
| I32x4 _ -> 4
| I64x2 _ -> 2
| F32x4 _ -> 4
| F64x2 _ -> 2
let type_of_lane = function
| I8x16 _ | I16x8 _ | I32x4 _ -> Types.I32T
| I64x2 _ -> Types.I64T
| F32x4 _ -> Types.F32T
| F64x2 _ -> Types.F64T
(* Shape-based operations *)
module Conversion (Lane : sig type t end) =
struct
module type T =
sig
val shape : shape
val to_lanes : t -> Lane.t list
val of_lanes : Lane.t list -> t
end
end
module type IntShape =
sig
type lane
val num_lanes : int
val to_lanes : t -> lane list
val of_lanes : lane list -> t
val splat : lane -> t
val extract_lane : int -> t -> lane
val replace_lane : int -> t -> lane -> t
val eq : t -> t -> t
val ne : t -> t -> t
val lt_s : t -> t -> t
val lt_u : t -> t -> t
val le_s : t -> t -> t
val le_u : t -> t -> t
val gt_s : t -> t -> t
val gt_u : t -> t -> t
val ge_s : t -> t -> t
val ge_u : t -> t -> t
val abs : t -> t
val neg : t -> t
val popcnt : t -> t
val add : t -> t -> t
val sub : t -> t -> t
val min_s : t -> t -> t
val min_u : t -> t -> t
val max_s : t -> t -> t
val max_u : t -> t -> t
val mul : t -> t -> t
val avgr_u : t -> t -> t
val any_true : t -> bool
val all_true : t -> bool
val bitmask : t -> Int32.t
val shl : t -> I32.t -> t
val shr_s : t -> I32.t -> t
val shr_u : t -> I32.t -> t
val add_sat_s : t -> t -> t
val add_sat_u : t -> t -> t
val sub_sat_s : t -> t -> t
val sub_sat_u : t -> t -> t
val q15mulr_sat_s : t -> t -> t
end
module MakeIntShape (IXX : Ixx.T) (Cvt : Conversion(IXX).T) :
IntShape with type lane = IXX.t =
struct
type lane = IXX.t
let num_lanes = num_lanes Cvt.shape
let of_lanes = Cvt.of_lanes
let to_lanes = Cvt.to_lanes
let unop f x = of_lanes (List.map f (to_lanes x))
let unopi f x = of_lanes (List.mapi f (to_lanes x))
let binop f x y = of_lanes (List.map2 f (to_lanes x) (to_lanes y))
let reduceop f a s = List.fold_left (fun a b -> f a (b <> IXX.zero)) a (to_lanes s)
let cmp f x y = if f x y then IXX.of_int_s (-1) else IXX.zero
let splat x = of_lanes (List.init num_lanes (fun i -> x))
let extract_lane i s = List.nth (to_lanes s) i
let replace_lane i v x = unopi (fun j y -> if j = i then x else y) v
let eq = binop (cmp IXX.eq)
let ne = binop (cmp IXX.ne)
let lt_s = binop (cmp IXX.lt_s)
let lt_u = binop (cmp IXX.lt_u)
let le_s = binop (cmp IXX.le_s)
let le_u = binop (cmp IXX.le_u)
let gt_s = binop (cmp IXX.gt_s)
let gt_u = binop (cmp IXX.gt_u)
let ge_s = binop (cmp IXX.ge_s)
let ge_u = binop (cmp IXX.ge_u)
let abs = unop IXX.abs
let neg = unop IXX.neg
let popcnt = unop IXX.popcnt
let add = binop IXX.add
let sub = binop IXX.sub
let mul = binop IXX.mul
let choose f x y = if f x y then x else y
let min_s = binop (choose IXX.le_s)
let min_u = binop (choose IXX.le_u)
let max_s = binop (choose IXX.ge_s)
let max_u = binop (choose IXX.ge_u)
(* The result of avgr_u will not overflow this type, but the intermediate might,
* so have the Int type implement it so they can extend it accordingly *)
let avgr_u = binop IXX.avgr_u
let any_true = reduceop (||) false
let all_true = reduceop (&&) true
(* Extract top bits using signed-comparision with zero *)
let bitmask x =
let xs = to_lanes x in
let negs = List.map (fun x -> if IXX.(lt_s x zero) then Int32.one else Int32.zero) xs in
List.fold_right (fun a b -> Int32.(logor a (shift_left b 1))) negs Int32.zero
let shl v s =
let shift = IXX.of_int_u (Int32.to_int s) in
unop (fun a -> IXX.shl a shift) v
let shr_s v s =
let shift = IXX.of_int_u (Int32.to_int s) in
unop (fun a -> IXX.shr_s a shift) v
let shr_u v s =
let shift = IXX.of_int_u (Int32.to_int s) in
unop (fun a -> IXX.shr_u a shift) v
let add_sat_s = binop IXX.add_sat_s
let add_sat_u = binop IXX.add_sat_u
let sub_sat_s = binop IXX.sub_sat_s
let sub_sat_u = binop IXX.sub_sat_u
(* The intermediate will overflow lane.t, so have Int implement this. *)
let q15mulr_sat_s = binop IXX.q15mulr_sat_s
end
module type FloatShape =
sig
type lane
val num_lanes : int
val to_lanes : t -> lane list
val of_lanes : lane list -> t
val splat : lane -> t
val extract_lane : int -> t -> lane
val replace_lane : int -> t -> lane -> t
val eq : t -> t -> t
val ne : t -> t -> t
val lt : t -> t -> t
val le : t -> t -> t
val gt : t -> t -> t
val ge : t -> t -> t
val abs : t -> t
val neg : t -> t
val sqrt : t -> t
val ceil : t -> t
val floor : t -> t
val trunc : t -> t
val nearest : t -> t
val add : t -> t -> t
val sub : t -> t -> t
val mul : t -> t -> t
val div : t -> t -> t
val fma : t -> t -> t -> t
val fnma : t -> t -> t -> t
val min : t -> t -> t
val max : t -> t -> t
val pmin : t -> t -> t
val pmax : t -> t -> t
end
module MakeFloatShape (FXX : Fxx.T) (Cvt : Conversion(FXX).T) :
FloatShape with type lane = FXX.t =
struct
type lane = FXX.t
let num_lanes = num_lanes Cvt.shape
let of_lanes = Cvt.of_lanes
let to_lanes = Cvt.to_lanes
let unop f x = of_lanes (List.map f (to_lanes x))
let unopi f x = of_lanes (List.mapi f (to_lanes x))
let binop f x y = of_lanes (List.map2 f (to_lanes x) (to_lanes y))
let all_ones = FXX.of_float (Int64.float_of_bits (Int64.minus_one))
let cmp f x y = if f x y then all_ones else FXX.zero
let splat x = of_lanes (List.init num_lanes (fun i -> x))
let extract_lane i s = List.nth (to_lanes s) i
let replace_lane i v x = unopi (fun j y -> if j = i then x else y) v
let eq = binop (cmp FXX.eq)
let ne = binop (cmp FXX.ne)
let lt = binop (cmp FXX.lt)
let le = binop (cmp FXX.le)
let gt = binop (cmp FXX.gt)
let ge = binop (cmp FXX.ge)
let abs = unop FXX.abs
let neg = unop FXX.neg
let sqrt = unop FXX.sqrt
let ceil = unop FXX.ceil
let floor = unop FXX.floor
let trunc = unop FXX.trunc
let nearest = unop FXX.nearest
let add = binop FXX.add
let sub = binop FXX.sub
let mul = binop FXX.mul
let div = binop FXX.div
let fma x y z =
of_lanes (Lib.List.map3 FXX.fma (to_lanes x) (to_lanes y) (to_lanes z))
let fnma x y z = fma (unop FXX.neg x) y z
let min = binop FXX.min
let max = binop FXX.max
let pmin = binop (fun x y -> if FXX.lt y x then y else x)
let pmax = binop (fun x y -> if FXX.lt x y then y else x)
end
module I8x16 = MakeIntShape (I8)
(struct
let shape = I8x16 ()
let to_lanes s =
List.init 16 (fun i -> I8.of_int_s (Bytes.get_int8 (Bytes.of_string s) i))
let of_lanes fs =
assert (List.length fs = 16);
let b = Bytes.create bytewidth in
List.iteri (fun i f -> Bytes.set_int8 b i (I8.to_int_s f)) fs;
Bytes.to_string b
end)
module I16x8 = MakeIntShape (I16)
(struct
let shape = I16x8 ()
let to_lanes s =
List.init 8 (fun i -> I16.of_int_s (Bytes.get_int16_le (Bytes.of_string s) (i*2)))
let of_lanes fs =
assert (List.length fs = 8);
let b = Bytes.create bytewidth in
List.iteri (fun i f -> Bytes.set_int16_le b (i*2) (I16.to_int_s f)) fs;
Bytes.to_string b
end)
module I32x4 = MakeIntShape (I32)
(struct
let shape = I32x4 ()
let to_lanes s =
List.init 4 (fun i -> Bytes.get_int32_le (Bytes.of_string s) (i*4))
let of_lanes fs =
assert (List.length fs = 4);
let b = Bytes.create bytewidth in
List.iteri (fun i f -> Bytes.set_int32_le b (i*4) f) fs;
Bytes.to_string b
end)
module I64x2 = MakeIntShape (I64)
(struct
let shape = I64x2 ()
let to_lanes s =
List.init 2 (fun i -> Bytes.get_int64_le (Bytes.of_string s) (i*8))
let of_lanes fs =
assert (List.length fs = 2);
let b = Bytes.create bytewidth in
List.iteri (fun i f -> Bytes.set_int64_le b (i*8) f) fs;
Bytes.to_string b
end)
module F32x4 = MakeFloatShape (F32)
(struct
let shape = F32x4 ()
let to_lanes s =
List.init 4 (fun i -> F32.of_bits (Bytes.get_int32_le (Bytes.of_string s) (i*4)))
let of_lanes fs =
assert (List.length fs = 4);
let b = Bytes.create bytewidth in
List.iteri (fun i f -> Bytes.set_int32_le b (i*4) (F32.to_bits f)) fs;
Bytes.to_string b
end)
module F64x2 = MakeFloatShape (F64)
(struct
let shape = F64x2 ()
let to_lanes s =
List.init 2 (fun i -> F64.of_bits (Bytes.get_int64_le (Bytes.of_string s) (i*8)))
let of_lanes fs =
assert (List.length fs = 2);
let b = Bytes.create bytewidth in
List.iteri (fun i f -> Bytes.set_int64_le b (i*8) (F64.to_bits f)) fs;
Bytes.to_string b
end)
(* Special shapes *)
module V1x128 =
struct
let unop f x = I64x2.of_lanes (List.map f (I64x2.to_lanes x))
let binop f x y =
I64x2.of_lanes (List.map2 f (I64x2.to_lanes x) (I64x2.to_lanes y))
let not_ = unop I64.not_
let and_ = binop I64.and_
let or_ = binop I64.or_
let xor = binop I64.xor
let andnot = binop (fun x y -> I64.and_ x (I64.not_ y))
let bitselect v1 v2 c =
let v2_andnot_c = andnot v2 c in
let v1_and_c = binop I64.and_ v1 c in
binop I64.or_ v1_and_c v2_andnot_c
end
module V8x16 =
struct
let swizzle v1 v2 =
let ns = I8x16.to_lanes v1 in
let is = I8x16.to_lanes v2 in
let select i =
Option.value (List.nth_opt ns (I8.to_int_u i)) ~default: I8.zero
in I8x16.of_lanes (List.map select is)
let shuffle is v1 v2 =
let ns = I8x16.to_lanes v1 @ I8x16.to_lanes v2 in
I8x16.of_lanes (List.map (List.nth ns) is)
end
(* Conversions *)
let narrow to_lanes cvt of_lanes x y =
of_lanes (List.map cvt (to_lanes x @ to_lanes y))
module I8x16_convert =
struct
let narrow_s = narrow I16x8.to_lanes Convert.I8_.narrow_sat_i16_s I8x16.of_lanes
let narrow_u = narrow I16x8.to_lanes Convert.I8_.narrow_sat_i16_u I8x16.of_lanes
end
module I16x8_convert =
struct
let narrow_s = narrow I32x4.to_lanes Convert.I16_.narrow_sat_i32_s I16x8.of_lanes
let narrow_u = narrow I32x4.to_lanes Convert.I16_.narrow_sat_i32_u I16x8.of_lanes
let extend take_or_drop ext x =
I16x8.of_lanes (List.map ext (take_or_drop 8 (I8x16.to_lanes x)))
let extend_low_s = extend Lib.List.take Convert.I16_.extend_i8_s
let extend_high_s = extend Lib.List.drop Convert.I16_.extend_i8_s
let extend_low_u = extend Lib.List.take Convert.I16_.extend_i8_u
let extend_high_u = extend Lib.List.drop Convert.I16_.extend_i8_u
let extmul_low_s x y = I16x8.mul (extend_low_s x) (extend_low_s y)
let extmul_high_s x y = I16x8.mul (extend_high_s x) (extend_high_s y)
let extmul_low_u x y = I16x8.mul (extend_low_u x) (extend_low_u y)
let extmul_high_u x y = I16x8.mul (extend_high_u x) (extend_high_u y)
let extadd ext x y = I16.add (ext x) (ext y)
let extadd_pairwise_s x =
I16x8.of_lanes (Lib.List.map_pairwise (extadd Convert.I16_.extend_i8_s) (I8x16.to_lanes x))
let extadd_pairwise_u x =
I16x8.of_lanes (Lib.List.map_pairwise (extadd Convert.I16_.extend_i8_u) (I8x16.to_lanes x))
let dot_s x y =
let xs = List.map Convert.I16_.extend_i8_s (I8x16.to_lanes x) in
let ys = List.map Convert.I16_.extend_i8_s (I8x16.to_lanes y) in
let rec dot xs ys =
match xs, ys with
| x1::x2::xs', y1::y2::ys' ->
I16.(add (mul x1 y1) (mul x2 y2)) :: dot xs' ys'
| [], [] -> []
| _, _ -> assert false
in I16x8.of_lanes (dot xs ys)
end
module I32x4_convert =
struct
let convert f v = I32x4.of_lanes (List.map f (F32x4.to_lanes v))
let trunc_sat_f32x4_s = convert Convert.I32_.trunc_sat_f32_s
let trunc_sat_f32x4_u = convert Convert.I32_.trunc_sat_f32_u
let convert_zero f v =
I32x4.of_lanes (List.map f (F64x2.to_lanes v) @ I32.[zero; zero])
let trunc_sat_f64x2_s_zero = convert_zero Convert.I32_.trunc_sat_f64_s
let trunc_sat_f64x2_u_zero = convert_zero Convert.I32_.trunc_sat_f64_u
let extend take_or_drop ext x =
I32x4.of_lanes (List.map ext (take_or_drop 4 (I16x8.to_lanes x)))
let extend_low_s = extend Lib.List.take Convert.I32_.extend_i16_s
let extend_high_s = extend Lib.List.drop Convert.I32_.extend_i16_s
let extend_low_u = extend Lib.List.take Convert.I32_.extend_i16_u
let extend_high_u = extend Lib.List.drop Convert.I32_.extend_i16_u
let dot_s x y =
let xs = List.map Convert.I32_.extend_i16_s (I16x8.to_lanes x) in
let ys = List.map Convert.I32_.extend_i16_s (I16x8.to_lanes y) in
let rec dot xs ys =
match xs, ys with
| x1::x2::xss, y1::y2::yss ->
Int32.(add (mul x1 y1) (mul x2 y2)) :: dot xss yss
| [], [] -> []
| _, _ -> assert false
in I32x4.of_lanes (dot xs ys)
let dot_add_s x y z =
let xs = List.map Convert.I32_.extend_i8_s (I8x16.to_lanes x) in
let ys = List.map Convert.I32_.extend_i8_s (I8x16.to_lanes y) in
let rec dot xs ys =
match xs, ys with
| x1::x2::x3::x4::xs', y1::y2::y3::y4::ys' ->
Int32.(add
(add (mul x1 y1) (mul x2 y2))
(add (mul x3 y3) (mul x4 y4))
) :: dot xs' ys'
| [], [] -> []
| _, _ -> assert false
in I32x4.add (I32x4.of_lanes (dot xs ys)) z
let extmul_low_s x y = I32x4.mul (extend_low_s x) (extend_low_s y)
let extmul_high_s x y = I32x4.mul (extend_high_s x) (extend_high_s y)
let extmul_low_u x y = I32x4.mul (extend_low_u x) (extend_low_u y)
let extmul_high_u x y = I32x4.mul (extend_high_u x) (extend_high_u y)
let extadd ext x y = Int32.add (ext x) (ext y)
let extadd_pairwise_s x =
I32x4.of_lanes (Lib.List.map_pairwise (extadd Convert.I32_.extend_i16_s) (I16x8.to_lanes x))
let extadd_pairwise_u x =
I32x4.of_lanes (Lib.List.map_pairwise (extadd Convert.I32_.extend_i16_u) (I16x8.to_lanes x))
end
module I64x2_convert =
struct
let extend take_or_drop ext x =
I64x2.of_lanes (List.map ext (take_or_drop 2 (I32x4.to_lanes x)))
let extend_low_s = extend Lib.List.take Convert.I64_.extend_i32_s
let extend_high_s = extend Lib.List.drop Convert.I64_.extend_i32_s
let extend_low_u = extend Lib.List.take Convert.I64_.extend_i32_u
let extend_high_u = extend Lib.List.drop Convert.I64_.extend_i32_u
let extmul_low_s x y = I64x2.mul (extend_low_s x) (extend_low_s y)
let extmul_high_s x y = I64x2.mul (extend_high_s x) (extend_high_s y)
let extmul_low_u x y = I64x2.mul (extend_low_u x) (extend_low_u y)
let extmul_high_u x y = I64x2.mul (extend_high_u x) (extend_high_u y)
end
module F32x4_convert =
struct
let convert f v = F32x4.of_lanes (List.map f (I32x4.to_lanes v))
let convert_i32x4_s = convert Convert.F32_.convert_i32_s
let convert_i32x4_u = convert Convert.F32_.convert_i32_u
let demote_f64x2_zero v =
F32x4.of_lanes
(List.map Convert.F32_.demote_f64 (F64x2.to_lanes v) @ F32.[zero; zero])
end
module F64x2_convert =
struct
let convert f v =
F64x2.of_lanes (List.map f (Lib.List.take 2 (I32x4.to_lanes v)))
let convert_i32x4_s = convert Convert.F64_.convert_i32_s
let convert_i32x4_u = convert Convert.F64_.convert_i32_u
let promote_low_f32x4 v =
F64x2.of_lanes
(List.map Convert.F64_.promote_f32 (Lib.List.take 2 (F32x4.to_lanes v)))
end
(* String conversion *)
let to_string s =
String.concat " " (List.map I32.to_string_s (I32x4.to_lanes s))
let to_hex_string s =
String.concat " " (List.map I32.to_hex_string (I32x4.to_lanes s))
let of_strings shape ss =
if List.length ss <> num_lanes shape then
invalid_arg "wrong length";
let open Bytes in
let b = create bytewidth in
(match shape with
| I8x16 () ->
List.iteri (fun i s -> set_uint8 b i (I8.to_int_u (I8.of_string s))) ss
| I16x8 () ->
List.iteri (fun i s -> set_int16_le b (i * 2) (I16.to_int_u (I16.of_string s))) ss
| I32x4 () ->
List.iteri (fun i s -> set_int32_le b (i * 4) (I32.of_string s)) ss
| I64x2 () ->
List.iteri (fun i s -> set_int64_le b (i * 8) (I64.of_string s)) ss
| F32x4 () ->
List.iteri (fun i s -> set_int32_le b (i * 4) (F32.to_bits (F32.of_string s))) ss
| F64x2 () ->
List.iteri (fun i s -> set_int64_le b (i * 8) (F64.to_bits (F64.of_string s))) ss
);
to_string b
let string_of_shape = function
| I8x16 _ -> "i8x16"
| I16x8 _ -> "i16x8"
| I32x4 _ -> "i32x4"
| I64x2 _ -> "i64x2"
| F32x4 _ -> "f32x4"
| F64x2 _ -> "f64x2"