(*
 * Integer view of the bit pattern of a floating-point format. This is the
 * parameter signature of the Make functor in this file.
 *)
module type RepType =
sig
  (* The raw bit pattern of a float value. *)
  type t

  (* Layout: total width in bits and width of the significand field.
   * Make derives the exponent width as bitwidth - mantissa - 1. *)
  val bitwidth : int
  val mantissa : int

  (* Distinguished bit patterns. Make uses min_int as the sign-bit mask
   * (neg, copysign), max_int as the magnitude mask (abs), pos_nan as the
   * canonical NaN that is ORed into NaN results, and bare_nan as the
   * exponent-field mask when validating NaN payloads. *)
  val zero : t
  val min_int : t
  val max_int : t
  val pos_nan : t
  val neg_nan : t
  val bare_nan : t

  (* Bitwise operations on raw patterns. *)
  val lognot : t -> t
  val logand : t -> t -> t
  val logor : t -> t -> t
  val logxor : t -> t -> t

  (* Conversions between raw patterns and OCaml floats / strings;
   * to_hex_string is what Make prints after the "nan:0x" prefix. *)
  val bits_of_float : float -> t
  val float_of_bits : t -> float
  val of_string : string -> t
  val to_string : t -> string
  val to_hex_string : t -> string
end
| |
(*
 * Operations on one floating-point format: [t] is a value of the format and
 * [bits] its raw bit representation.
 *)
module type T =
sig
  type t
  type bits

  (* Format layout in bits: total width, significand width, exponent width. *)
  val bitwidth : int
  val mantissa : int
  val exponent : int

  (* Special values. *)
  val zero : t
  val pos_inf : t
  val neg_inf : t
  val pos_nan : t
  val neg_nan : t

  (* Classification. *)
  val is_inf : t -> bool
  val is_nan : t -> bool

  (* Conversions to and from OCaml floats, raw bits and text. *)
  val of_float : float -> t
  val to_float : t -> float
  val of_bits : bits -> t
  val to_bits : t -> bits
  val of_string : string -> t
  val to_string : t -> string
  val to_hex_string : t -> string

  (* Arithmetic. *)
  val add : t -> t -> t
  val sub : t -> t -> t
  val mul : t -> t -> t
  val div : t -> t -> t
  val fma : t -> t -> t -> t
  val sqrt : t -> t
  val min : t -> t -> t
  val max : t -> t -> t

  (* Rounding to an integral value. *)
  val ceil : t -> t
  val floor : t -> t
  val trunc : t -> t
  val nearest : t -> t

  (* Sign manipulation. *)
  val abs : t -> t
  val neg : t -> t
  val copysign : t -> t -> t

  (* Comparisons. *)
  val eq : t -> t -> bool
  val ne : t -> t -> bool
  val lt : t -> t -> bool
  val le : t -> t -> bool
  val gt : t -> t -> bool
  val ge : t -> t -> bool
end
| |
(*
 * Floating-point operations for a format whose bit pattern is held in
 * [Rep.t]. Values are converted to OCaml floats for arithmetic and back with
 * [Rep.bits_of_float]; NaN results are post-processed so that their payload
 * follows the propagation rules documented below.
 *)
module Make (Rep : RepType) : T with type bits = Rep.t =
struct
  (* Layout sanity check: leaves at least two exponent bits
   * (exponent = bitwidth - mantissa - 1). *)
  let () = assert (Rep.mantissa < Rep.bitwidth - 2)

  type t = Rep.t
  type bits = Rep.t

  let bitwidth = Rep.bitwidth
  let mantissa = Rep.mantissa
  let exponent = bitwidth - mantissa - 1

  let pos_inf = Rep.bits_of_float infinity
  let neg_inf = Rep.bits_of_float neg_infinity
  let pos_nan = Rep.pos_nan
  let neg_nan = Rep.neg_nan
  let bare_nan = Rep.bare_nan

  let of_float = Rep.bits_of_float
  let to_float = Rep.float_of_bits

  (* [t] and [bits] share the same representation. *)
  let of_bits x = x
  let to_bits x = x

  let is_inf x = x = pos_inf || x = neg_inf
  let is_nan x = let xf = Rep.float_of_bits x in xf <> xf

  (*
   * When the result of an arithmetic operation is NaN, the most significant
   * bit of the significand field is set.
   *)
  let canonicalize_nan x = Rep.logor x Rep.pos_nan

  (*
   * When the result of an operation is NaN, the resulting NaN is computed
   * from one of the NaN operands, if there is one (here: the first one in
   * argument order). If no operand is NaN, we use a default NaN value.
   * Either way the result is canonicalized.
   *
   * TODO: There are nondeterministic choices we could make here. When
   * several operands are NaN, we can nondeterministically pick which one to
   * return. And when none is NaN, we can nondeterministically pick whether
   * to return pos_nan or neg_nan.
   *)
  let determine_nan operands =
    let nan =
      match List.find_opt is_nan operands with
      | Some x -> x
      | None -> Rep.pos_nan
    in
    canonicalize_nan nan

  let determine_unary_nan x = determine_nan [x]
  let determine_binary_nan x y = determine_nan [x; y]
  let determine_ternary_nan x y z = determine_nan [x; y; z]

  (*
   * Lift a float operation to [t]: a non-NaN result is converted back to the
   * target format, a NaN result is replaced according to the rules above.
   *)
  let binary x op y =
    let t = op (to_float x) (to_float y) in
    if t = t then of_float t else determine_binary_nan x y

  let unary op x =
    let t = op (to_float x) in
    if t = t then of_float t else determine_unary_nan x

  let zero = of_float 0.0

  let add x y = binary x (+.) y
  let sub x y = binary x (-.) y
  let mul x y = binary x ( *. ) y
  let div x y = binary x (/.) y

  (*
   * Fused multiply-add x * y + z. A NaN result derives from the first NaN
   * among all three operands, z included.
   * NOTE(review): Float.fma rounds to a 64 bit float; for narrower formats
   * of_float presumably rounds a second time, which may differ from a single
   * correctly rounded fma in the target format.
   *)
  let fma x y z =
    let t = Float.fma (to_float x) (to_float y) (to_float z) in
    if t = t then of_float t else determine_ternary_nan x y z

  let sqrt x = unary Stdlib.sqrt x

  let ceil x = unary Stdlib.ceil x
  let floor x = unary Stdlib.floor x

  let trunc x =
    let xf = to_float x in
    (* preserve the sign of zero *)
    if xf = 0.0 then x else
    (* trunc is either ceil or floor depending on which one is toward zero *)
    let f = if xf < 0.0 then Stdlib.ceil xf else Stdlib.floor xf in
    (* as in unary, a NaN result is derived from the operand itself *)
    if f = f then of_float f else determine_unary_nan x

  let nearest x =
    let xf = to_float x in
    (* preserve the sign of zero *)
    if xf = 0.0 then x else
    (* nearest is either ceil or floor depending on which is nearest or even *)
    let u = Stdlib.ceil xf in
    let d = Stdlib.floor xf in
    let um = abs_float (xf -. u) in
    let dm = abs_float (xf -. d) in
    let u_is_even = (let h = u /. 2. in Stdlib.floor h = h) in
    let f = if um < dm || (um = dm && u_is_even) then u else d in
    (* as in unary, a NaN result is derived from the operand itself *)
    if f = f then of_float f else determine_unary_nan x

  (* min and max treat -0 as smaller than 0; a NaN operand yields a NaN. *)
  let min x y =
    let xf = to_float x in
    let yf = to_float y in
    (* min -0 0 is -0 *)
    if xf = yf then Rep.logor x y else
    if xf < yf then x else
    if xf > yf then y else
    determine_binary_nan x y

  let max x y =
    let xf = to_float x in
    let yf = to_float y in
    (* max -0 0 is 0 *)
    if xf = yf then Rep.logand x y else
    if xf > yf then x else
    if xf < yf then y else
    determine_binary_nan x y

  (* abs, neg, copysign are purely bitwise operations, even on NaN values *)
  let abs x =
    Rep.logand x Rep.max_int

  let neg x =
    Rep.logxor x Rep.min_int

  let copysign x y =
    Rep.logor (abs x) (Rep.logand y Rep.min_int)

  (* Comparisons use OCaml float comparison: a NaN operand makes every
   * predicate false except ne. *)
  let eq x y = (to_float x = to_float y)
  let ne x y = (to_float x <> to_float y)
  let lt x y = (to_float x < to_float y)
  let gt x y = (to_float x > to_float y)
  let le x y = (to_float x <= to_float y)
  let ge x y = (to_float x >= to_float y)

  (*
   * Compare mantissa of two floats in string representation (hex or dec).
   * This is a gross hack to detect rounding during parsing of floats.
   * The helpers expect upper-case input.
   *)
  let is_hex c = ('0' <= c && c <= '9') || ('A' <= c && c <= 'F')
  let is_exp hex c = (c = if hex then 'P' else 'E')
  let at_end hex s i = (i = String.length s) || is_exp hex s.[i]

  let rec skip_non_hex s i = (* to skip sign, 'x', '.', '_', etc. *)
    if at_end true s i || is_hex s.[i] then i else skip_non_hex s (i + 1)

  let rec skip_zeroes s i =
    let i' = skip_non_hex s i in
    if at_end true s i' || s.[i'] <> '0' then i' else skip_zeroes s (i' + 1)

  let rec compare_mantissa_str' hex s1 i1 s2 i2 =
    let i1' = skip_non_hex s1 i1 in
    let i2' = skip_non_hex s2 i2 in
    match at_end hex s1 i1', at_end hex s2 i2' with
    | true, true -> 0
    | true, false -> if at_end hex s2 (skip_zeroes s2 i2') then 0 else -1
    | false, true -> if at_end hex s1 (skip_zeroes s1 i1') then 0 else +1
    | false, false ->
      (match compare s1.[i1'] s2.[i2'] with
       | 0 -> compare_mantissa_str' hex s1 (i1' + 1) s2 (i2' + 1)
       | n -> n)

  let compare_mantissa_str hex s1 s2 =
    let s1' = String.uppercase_ascii s1 in
    let s2' = String.uppercase_ascii s2 in
    compare_mantissa_str' hex s1' (skip_zeroes s1' 0) s2' (skip_zeroes s2' 0)

  (*
   * Convert a string to a float in target precision by going through
   * OCaml's 64 bit floats. This may incur double rounding errors in edge
   * cases, i.e., when rounding to target precision involves a tie that
   * was created by earlier rounding during parsing to float. If both
   * end up rounding in the same direction, we would "over round".
   * This function tries to detect this case and correct accordingly.
   *)
  let float_of_string_prevent_double_rounding s =
    (* First parse to a 64 bit float. *)
    let z = float_of_string s in
    (* A target with at least 52 mantissa bits is not rounded a second time.
     * (For exactly 52 the tie test below degenerates: mask and tie are both
     * 0, so every value looks like a tie, but eps comes out 0.) *)
    if Rep.mantissa >= 52 then z else
    (* If value is already infinite we are done. *)
    if abs_float z = infinity then z else
    (* Else, bit twiddling to see what rounding to target precision will do. *)
    let open Int64 in
    let bits = bits_of_float z in
    let lsb = shift_left 1L (52 - Rep.mantissa) in
    (* Check for tie, i.e. whether the bits right of target LSB are 10000... *)
    let tie = shift_right lsb 1 in
    let mask = lognot (shift_left (-1L) (52 - Rep.mantissa)) in
    (* If we have no tie, we are good. *)
    if logand bits mask <> tie then z else
    (* Else, define epsilon to be the value of the tie bit. *)
    let exp = float_of_bits (logand bits 0xfff0_0000_0000_0000L) in
    let eps = float_of_bits (logor tie (bits_of_float exp)) -. exp in
    (* Convert 64 bit float back to string to compare to input. *)
    let hex = String.contains s 'x' in
    let s' =
      if not hex then Printf.sprintf "%.*g" (String.length s) z else
      let m = logor (logand bits 0xf_ffff_ffff_ffffL) 0x10_0000_0000_0000L in
      (* Shift mantissa to match msb position in most significant hex digit *)
      let i = skip_zeroes (String.uppercase_ascii s) 0 in
      if i = String.length s then Printf.sprintf "%.*g" (String.length s) z else
      let sh =
        match s.[i] with '1' -> 0 | '2'..'3' -> 1 | '4'..'7' -> 2 | _ -> 3 in
      Printf.sprintf "%Lx" (shift_left m sh)
    in
    (* - If mantissa became larger, float was rounded up to tie already;
     *   round-to-even might round up again: sub epsilon to round down.
     * - If mantissa became smaller, float was rounded down to tie already;
     *   round-to-even might round down again: add epsilon to round up.
     * - If tie is not the result of prior rounding, then we are good.
     *)
    let c = compare_mantissa_str hex s s' in
    if c < 0 then z -. eps else
    if c > 0 then z +. eps else
    z

  (*
   * Parse a literal without a leading sign: inf, nan, nan:0x<payload>, or a
   * decimal/hex number that may contain '_' separators.
   * Raises Failure for a zero or out-of-range NaN payload, or when a number
   * overflows to infinity.
   *)
  let of_signless_string s =
    if s = "inf" then
      pos_inf
    else if s = "nan" then
      pos_nan
    else if String.length s > 6 && String.sub s 0 6 = "nan:0x" then
      (* the payload is handed to Rep.of_string with its 0x prefix *)
      let x = Rep.of_string (String.sub s 4 (String.length s - 4)) in
      if x = Rep.zero then
        failwith "nan payload must not be zero"
      else if Rep.logand x bare_nan <> Rep.zero then
        failwith "nan payload must not overlap with exponent bits"
      else if x < Rep.zero then
        failwith "nan payload must not overlap with sign bit"
      else
        Rep.logor x bare_nan
    else
      (* drop digit-group separators before handing off to float_of_string *)
      let s' = String.concat "" (String.split_on_char '_' s) in
      let x = of_float (float_of_string_prevent_double_rounding s') in
      if is_inf x then failwith "of_string" else x

  (* Parse a literal with an optional leading '+' or '-'. *)
  let of_string s =
    if s = "" then
      failwith "of_string"
    else if s.[0] = '+' || s.[0] = '-' then
      let x = of_signless_string (String.sub s 1 (String.length s - 1)) in
      if s.[0] = '+' then x else neg x
    else
      of_signless_string s

  (* String conversion that groups digits for readability *)

  let is_digit c = '0' <= c && c <= '9'
  let is_hex_digit c = is_digit c || 'a' <= c && c <= 'f'

  (* Copy s.[i..j) into buf, emitting '_' before a character whenever the
   * countdown k (taken modulo n) has reached 0. *)
  let rec add_digits buf s i j k n =
    if i < j then begin
      if k = 0 then Buffer.add_char buf '_';
      Buffer.add_char buf s.[i];
      add_digits buf s (i + 1) j ((k + n - 1) mod n) n
    end

  (* Insert '_' every n digits: integer digits are grouped from the right,
   * fraction digits from the left; everything else is copied verbatim. *)
  let group_digits is_digit n s =
    let isnt_digit c = not (is_digit c) in
    let len = String.length s in
    let x = Lib.Option.get (Lib.String.find_from_opt ((=) 'x') s 0) 0 in
    let mant = Lib.Option.get (Lib.String.find_from_opt is_digit s x) len in
    let point = Lib.Option.get (Lib.String.find_from_opt isnt_digit s mant) len in
    let frac = Lib.Option.get (Lib.String.find_from_opt is_digit s point) len in
    let exp = Lib.Option.get (Lib.String.find_from_opt isnt_digit s frac) len in
    let buf = Buffer.create (len*(n+1)/n) in
    Buffer.add_substring buf s 0 mant;
    add_digits buf s mant point ((point - mant) mod n + n) n;
    Buffer.add_substring buf s point (frac - point);
    add_digits buf s frac exp n n;
    Buffer.add_substring buf s exp (len - exp);
    Buffer.contents buf

  (* Render x: a sign prefix, then either nan:0x<payload> or the magnitude
   * formatted by [convert] and grouped in runs of n digits. *)
  let to_string' convert is_digit n x =
    let sign = if x < Rep.zero then "-" else "" in
    let magnitude =
      if is_nan x then
        let payload = Rep.logand (abs x) (Rep.lognot bare_nan) in
        "nan:0x" ^ group_digits is_hex_digit 4 (Rep.to_hex_string payload)
      else
        let s = convert (to_float (abs x)) in
        group_digits is_digit n
          (if s.[String.length s - 1] = '.' then s ^ "0" else s)
    in
    sign ^ magnitude

  let to_string = to_string' (Printf.sprintf "%.17g") is_digit 3

  (* Infinities are rendered through to_string rather than with "%h". *)
  let to_hex_string x =
    if is_inf x then to_string x else
    to_string' (Printf.sprintf "%h") is_hex_digit 4 x
end