@stedolan
Last active August 20, 2025 15:33

Generating good code for bitfields

Here's a definition of the packed 16-bit RGB565 pixel format as a C struct:

typedef struct { unsigned r : 5, g : 6, b : 5; } pixel;

and a couple of functions that operate on it:

void process_green(short green);

void process_pixel(pixel px) {
    process_green(px.g << 2);
}

void process_grb(unsigned g, unsigned r, unsigned b) {
    process_pixel((pixel){r,g,b});
}

These functions both involve many bit-packing operations: process_pixel must extract the 6-bit g field, shift it left to become 8 bits wide, then convert it to short to pass it on to process_green. Under the x86-64 calling convention, short arguments have 16 bits of data, which should be sign-extended to 32 bits and then zero-extended from there to 64 bits. (This seems, and is, very odd. However, it's what you naturally get on x86 using the 16-to-32 bit sign extension instructions.) process_grb is similar, needing a few extra shifting and masking operations to assemble the packed bits. I've chosen to use an unusual grb order for the parameters here, so that the g parameter stays in the rdi register all the time, making the resulting assembly slightly easier to read.

Clang/LLVM 20 and GCC 15 both generate good code for these functions. Below is the Clang version, but GCC is almost identical:

process_pixel(pixel):
        shr     edi, 3
        and     edi, 252
        jmp     process_green(short)@PLT

process_grb(unsigned int, unsigned int, unsigned int):
        shl     edi, 2
        movzx   edi, dil
        jmp     process_green(short)@PLT

All of the bit-packing has collapsed to the optimal sequence of a shift and a mask. (The movzx instruction is a zero-extension instruction from 8 bits, equivalent to masking with 0xff).

But if we change the code to use signed fields by replacing unsigned with int above, then the codegen gets much noisier, with Clang producing:

process_pixel(pixel):
        shr     edi, 3
        movsx   eax, dil
        and     eax, 65532
        movsx   edi, ax
        jmp     process_green(short)@PLT

process_grb(int, int, int):
        shl     edi, 10
        movsx   edi, di
        sar     edi, 8
        jmp     process_green(short)@PLT

and GCC producing:

process_pixel(pixel):
        sal     edi, 5
        sar     di, 10
        sal     edi, 2
        movsx   edi, di
        jmp     process_green(short)

process_grb(int, int, int):
        sal     edi, 2
        sar     dil, 2
        movsx   di, dil
        sal     edi, 2
        movsx   edi, di
        jmp     process_green(short)

Neither of these is optimal. This function is genuinely trickier than the unsigned version: in particular, the sign-extension to short cannot simply be optimised away. However, there's no need to do multiple sign extension operations in each function.

Here are some notes on compiling such sequences to good code, with an implementation attached below. The intended application is the OCaml compiler, which is currently behind Clang and GCC for this sort of code. Despite that, it's particularly important for OCaml: all standard library integer types are signed (so sign-extensions are common), and the tagged integer representation means that some bit-packing occurs even in programs that just use plain int.

With these techniques, the two functions above can be compiled to:

process_pixel(pixel):
        shl     edi, 21
        sar     edi, 26
        shl     edi, 2
        jmp     process_green(short)

process_grb(int, int, int):
        shl     edi, 26
        sar     edi, 24
        jmp     process_green(short)

Bit-windowing functions

The approach here is based on bit-windowing functions W(x), where the function W can be expressed in normal form as a 6-tuple (i,j,k,l,s,T), describing the following operation:

      64                           j             i                   0
      +--------------------------------------------------------------+
   x: |                            |    data     |                   |
      +--------------------------------------------------------------+

      64        s          l             k                           0
      +--------------------------------------------------------------+
W(x): |    0    |   sign   |    data     |            T              |
      +--------------------------------------------------------------+

That is, bits [j:i] of the input are extracted and shifted to position [l:k] of the output, which is then sign-extended to position s and zero-extended from there to the end. The low bits (below k) are filled with the constant T.

To be a well-formed bit-window function, the tuple must satisfy the following (assuming a 64-bit machine):

  • 0 ≤ i < j ≤ 64
  • 0 ≤ k < l ≤ s ≤ 64
  • j - i = l - k
  • T < 2^k

Note in particular the strict <: the window must include at least one bit. The third constraint, that the input and output windows are the same length, means that the 6-tuple representation is slightly redundant: we always have l = k + j - i. However, it's much simpler to manipulate these tuples if we have l available explicitly.

The T component is not essential: more or less everything below still works fine if you assume T to always be zero. However, for the use-case of OCaml integer tagging (shift left by one and set the low bit), it's convenient to combine updating the low bits into bit-windowing functions.

To make things a bit more readable, I'll write the tuples in the syntax [j:i] → s/[l:k]+T.
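
As a rough illustration of the definition above, here is a minimal C sketch that evaluates a window directly from its tuple. The window struct and the helper names (bits, window_ok, window_eval) are made up for this example; the implementation attached below is in OCaml and computes the same thing with four shifts.

```c
#include <assert.h>
#include <stdint.h>

/* Hypothetical C mirror of the 6-tuple (i,j,k,l,s,T). */
typedef struct { unsigned i, j, k, l, s; uint64_t t; } window;

/* All-ones in bit positions lo <= bit < hi, for 0 <= lo <= hi <= 64. */
static uint64_t bits(unsigned hi, unsigned lo) {
    uint64_t below_hi = hi >= 64 ? ~UINT64_C(0) : (UINT64_C(1) << hi) - 1;
    uint64_t below_lo = lo >= 64 ? ~UINT64_C(0) : (UINT64_C(1) << lo) - 1;
    return below_hi & ~below_lo;
}

/* The four well-formedness constraints listed in the text. */
static int window_ok(window w) {
    return w.i < w.j && w.j <= 64
        && w.k < w.l && w.l <= w.s && w.s <= 64
        && w.j - w.i == w.l - w.k
        && (w.t & ~bits(w.k, 0)) == 0;
}

/* Evaluate W(x) by following the picture: extract, place, extend, add T. */
uint64_t window_eval(window w, uint64_t x) {
    assert(window_ok(w));
    uint64_t data = (x & bits(w.j, w.i)) >> w.i;  /* input bits [j:i], right-aligned */
    uint64_t out = data << w.k;                   /* moved to output bits [l:k] */
    if ((out >> (w.l - 1)) & 1)                   /* top data bit is the sign bit */
        out |= bits(w.s, w.l);                    /* sign-extend up to position s */
    return out | w.t;                             /* bits above s stay 0, bits below k are T */
}
```

For instance, window_eval((window){5, 11, 0, 6, 32, 0}, px) evaluates [11:5] → 32/[6:0]+0, which reads the signed g field of a pixel as a 32-bit int, and (window){0, 63, 1, 64, 64, 1} is the OCaml tagging operation (shift left by one and set the low bit).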

The 6-tuple is a useful representation of bit-windowing functions:

  • They include basic shift and mask operations:

    • [64-n:0] → 64/[ 64:n]+0 is left shift by n
    • [ 64:n] → 64-n/[64-n:0]+0 is right logical shift by n
    • [ 64:n] → 64/[64-n:0]+0 is right arithmetic shift by n
    • [ 8:0] → 8/[ 8:0]+0 is 8- to 64-bit zero extension
    • [ 32:0] → 64/[ 32:0]+0 is 32- to 64-bit sign extension
    • [ j:i] → j/[ j:i]+0 is masking to keep only bits [j:i]
  • They also include some weirder x86-specific instructions

    • [16:8] → 32/[8:0]+0 is x86 movsx eax, ah: extract bits 8 through 16 of register rax, sign-extend to 32 bits, then zero-extend to 64 bits.

    • [32:10] → 32/[22:0]+0 is x86 sar eax, 10: do an arithmetic right shift of the low 32 bits of rax, zeroing the upper 32 bits.

  • The 6-tuple form is unique: two bit-windowing operations map the same inputs to the same outputs if and only if they have the same 6-tuple representation.

  • The 6-tuple form is closed (*ish) under composition, and compositions are easy (*ish) to compute.

The first (*ish) caveat above is that the composition of two bit-windowing functions is either another bit-windowing function or a constant. For instance, the composition of two nonoverlapping masks is the constant zero, and constant functions are not representable by a 6-tuple. The class of functions consisting of bit-windowing and constant maps is closed under composition, and the full composition function is as follows, whence the second (*ish):

let compose w1 w2 =
  let d1 = w1.k - w1.i in
  let d2 = w2.k - w2.i in
  let i = max w1.i (w2.i - d1) in
  let j = min w1.j (w2.j - d1) in
  let k = max (w1.k + d2) w2.k in
  let l = min (w1.l + d2) w2.l in
  assert (l - k = j - i);
  let s =
    if w2.j <= w1.s then w2.s (* w2 sign extension *)
    else w1.s + d2 (* w1 sign extension, w2 extends a zero *)
  in
  let t = eval w2 w1.t in
  if i < j then
    Window (make ~i ~j ~k ~l ~s ~t)
  else begin
    (* No overlap between windows, so probably constant *)
    assert (w2.j <= w1.k || w1.l <= w2.i);
    if w2.i < w1.s && w1.k < w2.j then
      (* Tricky case: Even though windows don't overlap, this isn't constant:
         it depends on some sign-extended bits from w1, so we extract just the sign bit *)
      Window (make ~i:(j-1) ~j ~k ~l:(k+1) ~s ~t)
    else
      Const t
  end

The intuition here is that you compose two windowing functions W₁ and W₂ by computing the overlap between the output window of W₁ and the input window of W₂. If the windows overlap, then the result maps the input of W₁ to the output of W₂ (suitably narrowed), and if they don't then the output is usually a constant. (The tricky case is when the windows don't overlap, but the output is non-constant because the window of W₂ extracts sign-extension bits of W₁.)

(Someday I'll get around to convincing Z3 or Rocq that this composition function works. For now, the evidence is some paper scribbling, plus the fact that it's passed a few million random tests)
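
To make the two non-overlapping cases concrete, here is a small C sketch with hand-picked windows (the function names are invented for this example). Two disjoint masks compose to the constant zero. An 8-bit sign extension followed by a mask of bits 16 to 32 is the tricky case: tracing the compose function above by hand gives the one-bit window [8:7] → 32/[17:16]+0, which copies input bit 7 into output bits 16 through 31.

```c
#include <assert.h>
#include <stdint.h>

/* W1 = [8:0] -> 64/[8:0]+0: sign-extend the low byte to 64 bits. */
uint64_t w1_sext8(uint64_t x) {
    uint64_t b = x & 0xFF;
    return (b ^ 0x80) - 0x80;   /* portable sign extension, wraps mod 2^64 */
}

/* W2 = [32:16] -> 32/[32:16]+0: keep only bits 16..31. */
uint64_t w2_mask(uint64_t x) { return x & UINT64_C(0xFFFF0000); }

/* Two disjoint masks, [8:0] then [32:16]: the constant 0. */
uint64_t two_masks(uint64_t x) { return w2_mask(x & 0xFF); }

/* The single window [8:7] -> 32/[17:16]+0: one data bit (input bit 7)
   placed at output bit 16 and sign-extended up to position 32. */
uint64_t composed(uint64_t x) {
    return ((x >> 7) & 1) ? UINT64_C(0xFFFF0000) : 0;
}

/* Returns 1 if both claims hold for every 16-bit input. */
int tricky_case_holds(void) {
    for (uint64_t x = 0; x < 0x10000; x++) {
        if (w2_mask(w1_sext8(x)) != composed(x)) return 0;
        if (two_masks(x) != 0) return 0;
    }
    return 1;
}
```

The windows of W1 and W2 do not overlap here either, yet the result still depends on x, because W2 reads bits that W1 filled with copies of its sign bit.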

Simplifying expressions with window functions

The process_pixel function above is the composition of several shift functions:

[11:5]→32/[ 6:0]+0    ;; Extract 6-bit int bitfield to 32-bit signed int
[ 6:0]→32/[ 8:2]+0    ;; 32-bit left shift by 2, restricted to the 6 data bits (the full shift is [30:0]→32/[32:2]+0)
[16:0]→32/[16:0]+0    ;; Conversion to short (sign-extend to 32, zero-extend to 64)

Using the window composition function, this transforms to a single window operation:

[11:5]→32/[ 8:2]+0
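
As a sanity check on this composition, the following C sketch compares the three steps done one at a time against the three-instruction sequence shown earlier (shl 21, sar 26, shl 2 on edi). The function names are invented for this example, and sar32 is a portable stand-in for the sar instruction, since right-shifting a negative signed value is implementation-defined in C.

```c
#include <assert.h>
#include <stdint.h>

/* Portable 32-bit arithmetic right shift, for 0 < n < 32. */
static uint32_t sar32(uint32_t x, unsigned n) {
    uint32_t fill = (x >> 31) ? (uint32_t)~(UINT32_MAX >> n) : 0;
    return (x >> n) | fill;
}

/* The three windows applied one after another. */
uint64_t process_pixel_steps(uint16_t px) {
    uint32_t g = (px >> 5) & 0x3F;              /* [11:5]: the 6-bit g field */
    uint32_t g32 = (g ^ 0x20) - 0x20;           /* sign-extend 6 -> 32 bits */
    uint32_t shifted = g32 << 2;                /* 32-bit left shift by 2 */
    uint32_t as_short = ((shifted & 0xFFFF) ^ 0x8000) - 0x8000; /* short: sign-extend 16 -> 32 */
    return as_short;                            /* zero-extended to 64 bits */
}

/* The single window [11:5] -> 32/[8:2]+0 as three 32-bit instructions. */
uint64_t process_pixel_window(uint64_t rdi) {
    uint32_t edi = (uint32_t)rdi;
    edi <<= 21;             /* shl edi, 21: top bit of g moves to bit 31 */
    edi = sar32(edi, 26);   /* sar edi, 26: g right-aligned and sign-extended */
    edi <<= 2;              /* shl edi, 2 */
    return edi;             /* 32-bit operations zero the upper half of rdi */
}

/* Returns 1 if the two agree on every 16-bit pixel value. */
int process_pixel_agrees(void) {
    for (uint32_t px = 0; px <= 0xFFFF; px++)
        if (process_pixel_steps((uint16_t)px) != process_pixel_window(px)) return 0;
    return 1;
}
```

A pixel with g = 63 (that is, -1 as a signed 6-bit field) comes out as 0xFFFFFFFC, which is -4 as a short, sign-extended to 32 bits and zero-extended to 64.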

The process_grb function is more complicated. It starts by constructing a pixel by converting its arguments to bitfields, then applies process_pixel, giving:

process_grb(g, r, b) = Wp(Wr(r) | Wg(g) | Wb(b))
  where
    Wp = [11:5]→32/[ 8: 2]+0 ;; as above
    Wr = [ 5:0]→16/[16:11]+0
    Wg = [ 6:0]→11/[11: 5]+0
    Wb = [ 5:0]→ 5/[ 5: 0]+0

To optimise this, we notice that every window function W satisfies W(a|b) = W(a) | W(b), so we can simplify to:

process_grb(g, r, b) = ((Wp⋅Wr)(r) | (Wp⋅Wg)(g) | (Wp⋅Wb)(b))

Since process_pixel extracts only the g component, the two windows Wp⋅Wr and Wp⋅Wb are the constant function 0, and simplifying x|0 = x yields a single window [6:0]→32/[8:2]+0.
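
The simplification can be checked by brute force. This C sketch (names invented for the example, field layout taken from the Wr, Wg and Wb windows above) verifies that Wp distributes over the OR, that the r and b terms vanish, and that what remains matches the two-instruction sequence shl 26, sar 24. As in any such sketch, sar32 is a portable stand-in for the sar instruction.

```c
#include <assert.h>
#include <stdint.h>

/* Portable 32-bit arithmetic right shift, for 0 < n < 32. */
static uint32_t sar32(uint32_t x, unsigned n) {
    uint32_t fill = (x >> 31) ? (uint32_t)~(UINT32_MAX >> n) : 0;
    return (x >> n) | fill;
}

/* Wp = [11:5] -> 32/[8:2]+0: process_pixel as one window. */
uint64_t wp(uint64_t x) {
    uint32_t g = (uint32_t)(x >> 5) & 0x3F;
    uint32_t sext = (g ^ 0x20) - 0x20;       /* sign-extend 6 -> 32 bits */
    return (uint32_t)(sext << 2);
}

/* The packing windows from the text. */
uint64_t wr(uint64_t r) { return (r & 0x1F) << 11; }  /* [5:0] -> 16/[16:11]+0 */
uint64_t wg(uint64_t g) { return (g & 0x3F) << 5; }   /* [6:0] -> 11/[11:5]+0 */
uint64_t wb(uint64_t b) { return b & 0x1F; }          /* [5:0] ->  5/[5:0]+0 */

/* Wp.Wg = [6:0] -> 32/[8:2]+0 as two 32-bit instructions. */
uint64_t process_grb_window(uint64_t g) {
    uint32_t edi = (uint32_t)g << 26;   /* shl edi, 26 */
    edi = sar32(edi, 24);               /* sar edi, 24 */
    return edi;
}

/* Returns 1 if every step of the simplification holds for all field values. */
int grb_simplification_holds(void) {
    for (uint64_t g = 0; g < 64; g++)
        for (uint64_t r = 0; r < 32; r++)
            for (uint64_t b = 0; b < 32; b++) {
                uint64_t whole = wp(wr(r) | wg(g) | wb(b));
                uint64_t parts = wp(wr(r)) | wp(wg(g)) | wp(wb(b));
                if (whole != parts) return 0;                     /* W(a|b) = W(a)|W(b) */
                if (wp(wr(r)) != 0 || wp(wb(b)) != 0) return 0;   /* constant 0 terms */
                if (whole != process_grb_window(g)) return 0;     /* single window */
            }
    return 1;
}
```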

A few more transformations are possible for general window functions:

  • Windows distribute over bitwise AND just as they do over OR, and over XOR too if you zero one side's T component (so you don't end up XOR-ing T with itself).

  • An operation involving a window and a small constant W(x) op N can often be folded into a single window operation, if 0 ≤ T op N < 2^k and op is addition, subtraction, AND, OR, or XOR.
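
A small C sketch of the constant-folding condition, using two hand-picked windows (the names are invented for this example). With k = 4 and T = 5, adding 3 only changes the constant part to 8; with OCaml-style tagging (k = 1, T = 1), XOR with 1 clears the tag. Adding 11 to the first window is not foldable, since 5 + 11 = 16 is not below 2^4 and carries into the data bits.

```c
#include <assert.h>
#include <stdint.h>

/* W = [8:0] -> 12/[12:4]+5: low byte moved up 4 bits, low bits set to T = 5. */
uint64_t w_t5(uint64_t x) { return ((x & 0xFF) << 4) | 5; }

/* The same window with T = 8, i.e. W(x) + 3 folded into the constant part. */
uint64_t w_t8(uint64_t x) { return ((x & 0xFF) << 4) | 8; }

/* OCaml-style tagging: [63:0] -> 64/[64:1]+1. */
uint64_t tag(uint64_t x) { return (x << 1) | 1; }

/* Returns 1 if folding agrees with doing the arithmetic afterwards. */
int folding_agrees(void) {
    for (uint64_t x = 0; x < 256; x++) {
        if (w_t5(x) + 3 != w_t8(x)) return 0;     /* T + 3 = 8 < 2^4 */
        if ((tag(x) ^ 1) != (x << 1)) return 0;   /* T ^ 1 = 0 < 2^1 */
    }
    return 1;
}
```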

Generating reasonable code for window functions

The above means that any sequence of window functions can be collapsed to a single one, and distributed over bitwise or as needed, so what remains is to generate good code for a single window function.

If T = 0, then this can be done in two instructions if no sign extension is involved (that is, if l = s), or three instructions in general. This becomes 3 or 4 instructions, respectively, to add a nonzero T afterwards.

In the l = s case, you can use a shift instruction to shift by k - i (that is, either left or right as needed to place bit i of the input into position k of the output), and then use a bitwise-AND instruction to select only the bits from k to l.

In the l < s case, where sign extension occurs, the first step is to shift left by 64 - j, putting the bit to be sign-extended into the top bit, then to perform an arithmetic right shift by 64 - l (to move the bits to the right location, sign-extending all the way to the top bit), and finally to use a third instruction to select the required bits, from k to s.

This simple strategy, plus some special cases to detect cases that can be implemented as a single instruction, generates reasonably short code - never more than a few instructions for an arbitrary sequence of shifting, masking and sign- or zero-extension.
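
The simple strategy can be written down as a short C sketch for T = 0 windows. This is an illustration only: window_simple and its helpers are names invented here, each C statement stands for one machine instruction, and sar64 is a portable substitute for a 64-bit sar.

```c
#include <assert.h>
#include <stdint.h>

/* Portable 64-bit arithmetic right shift, for 0 <= n < 64. */
static uint64_t sar64(uint64_t x, unsigned n) {
    if (n == 0) return x;
    uint64_t fill = (x >> 63) ? ~(UINT64_MAX >> n) : 0;
    return (x >> n) | fill;
}

/* All-ones in bit positions lo <= bit < hi, for 0 <= lo <= hi <= 64. */
static uint64_t mask_bits(unsigned hi, unsigned lo) {
    uint64_t below_hi = hi >= 64 ? ~UINT64_C(0) : (UINT64_C(1) << hi) - 1;
    uint64_t below_lo = lo >= 64 ? ~UINT64_C(0) : (UINT64_C(1) << lo) - 1;
    return below_hi & ~below_lo;
}

/* Evaluate the window [j:i] -> s/[l:k]+0 using the simple strategy. */
uint64_t window_simple(unsigned i, unsigned j, unsigned k, unsigned l,
                       unsigned s, uint64_t x) {
    assert(i < j && j <= 64 && k < l && l <= s && s <= 64 && j - i == l - k);
    if (l == s) {
        /* No sign extension: shift by k - i, then mask to [l:k]. */
        x = k >= i ? x << (k - i) : x >> (i - k);
        return x & mask_bits(l, k);
    }
    x <<= 64 - j;               /* sign bit of the data moves to bit 63 */
    x = sar64(x, 64 - l);       /* data lands at [l:k], sign fills the top */
    return x & mask_bits(s, k); /* keep sign bits up to s, drop junk below k */
}
```

For example, window_simple(8, 16, 0, 8, 32, x) behaves like movsx eax, ah, and window_simple(10, 64, 0, 54, 54, x) is a logical right shift by 10. A real code generator would of course drop shifts by zero and masks that keep every bit.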

Generating optimal code for window functions

Since window functions have an easily computable normal form, it's not much more work to generate optimal code. As always, "optimal" means "optimal with respect to a cost model", and the cost model I'm going to use here is:

  • Window functions are compiled to a sequence of x86_64 instructions, consisting of 32- and 64-bit left/right logical/arithmetic shifts, AND with constant masks, 32-to-64-bit zero-extension (emitted as a 32-bit mov of a register to itself), and 32-to-64-bit sign-extension (movsxd).

  • Each instruction has equal cost (all are single-cycle-latency instructions on recent machines), except for AND with a mask that does not fit in a 32-bit immediate (as these require an extra instruction to load the immediate). These large mask instructions have a cost greater than one simple instruction but less than two.

Each of the instructions considered can be expressed as a window function itself. This means that a sequence of such instructions can be decompiled back into a 6-tuple, by composing the window functions of each instruction. This makes it easy to establish correctness and optimality:

  • The compilation is correct iff decompiling the code generated for any window function yields the same window function.

  • The compilation is optimal iff decompiling an arbitrary sequence of instructions and compiling the resulting window function yields a sequence of instructions of at most the same cost.

The number of window functions is a bit too big to comfortably test exhaustively, so correctness is tested with a large number of random tests. Optimality can be tested exhaustively, though: the worst-case code generated for an arbitrary window is two shifts and a mask, so we need only test code sequences cheaper than that, of which there are about 30 million.

The optimal code generator ends up about 80 lines of (admittedly tricky) code, written by hacking until the correctness and optimality checkers stop complaining. Each example of nonoptimal code that the optimality checker came up with (that is, code sequences that are cheaper than what the compiler produces) was generalised to a strategy for certain window functions, yielding a number of interesting tricks:

  • Some shift-and-mask operations are cheaper to do as a pair of shifts due to the size of the masks involved, especially if you're willing to mix-and-match 32-bit and 64-bit shifts (and benefit from the implicit zero-extension of the 32-bit ones).

  • When sign-extension is required, there are three different shift counts that should be considered for the arithmetic right shift. (The optimal generator tries all three and picks the cheapest)

    • Shift by 64-l: This is the basic strategy described above, which puts the data in the right place but needs masking.

    • Shift by i (the position of the lowest data bit once the sign bit has been moved to the top): This right-aligns the data. This likely requires a left shift to place it correctly, but may save a mask (sometimes more expensive than a shift).

    • Shift by s-l: This leaves the data in the wrong place, but generates the right number of sign bits, so no masking is needed to remove excess ones. Again, sometimes this can be finished off with a simple shift, saving a mask.

  • When sign-extension is required, sometimes it's better to move the sign bit into position 32 and use a 32-bit sign extension, instead of position 64.

  • When both logical shifting and masking are being done, masking should be done before a left shift but after a right shift. (This means that the set bits of the mask are in the lower bit positions both times, which maximises the chance of them fitting in a small immediate)
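
Two of these tricks are easy to illustrate in C. The sketch below uses invented function names; cheap_mask mirrors the mask case of the cost function in the listing that follows. Masking before a left shift keeps the mask small, and a mask too wide for a 32-bit immediate can be replaced by a pair of shifts.

```c
#include <assert.h>
#include <stdint.h>

/* Window [8:0] -> 48/[48:40]+0, computed two ways. */
uint64_t shift_then_mask(uint64_t x) { return (x << 40) & UINT64_C(0x0000FF0000000000); }
uint64_t mask_then_shift(uint64_t x) { return (x & 0xFF) << 40; }

/* Window [40:0] -> 40/[40:0]+0 (keep the low 40 bits), computed two ways. */
uint64_t wide_mask(uint64_t x)  { return x & UINT64_C(0x000000FFFFFFFFFF); }
uint64_t two_shifts(uint64_t x) { return (x << 24) >> 24; }

/* Mirrors the cost model: a mask over bits [hi:lo] is cheap when it fits a
   32-bit AND (hi <= 32) or a sign-extended 32-bit immediate (hi = 64, lo < 32). */
int cheap_mask(unsigned hi, unsigned lo) {
    return hi <= 32 || (hi == 64 && lo < 32);
}
```

Here the mask over [8:0] is cheap while the one over [48:40] is not, so mask_then_shift is the better order. The mask over [40:0] is expensive in either position, which is why two_shifts wins under this cost model.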

(** Bit-windowing functions and their composition operator *)
module Window = struct
  type t = {
    i : int;
    j : int;
    k : int;
    l : int;
    s : int;
    t : Nativeint.t
  }

  let width = Sys.word_size

  let make ~i ~j ~k ~l ~s ~t =
    if 0 <= i && i < j && j <= width &&
       0 <= k && k < l && l <= s && s <= width &&
       (l - k) = (j - i) &&
       Nativeint.shift_right_logical t k = 0n then
      { i; j; k; l; s; t }
    else
      raise (Invalid_argument "Window.make")

  let to_string w =
    Printf.sprintf "[%d:%d] -> %d/[%d:%d]+%nx" w.j w.i w.s w.l w.k w.t

  let eval w n =
    let open Nativeint in
    let n = shift_right_logical n w.i in
    let n = shift_left n (width - (w.j - w.i)) in
    let n = shift_right n (w.s - w.l) in
    let n = shift_right_logical n (width - w.s) in
    let n = logor n w.t in
    n

  type window_or_const = Window of t | Const of Nativeint.t

  let compose w1 w2 =
    let d1 = w1.k - w1.i in
    let d2 = w2.k - w2.i in
    let i = max w1.i (w2.i - d1) in
    let j = min w1.j (w2.j - d1) in
    let k = max (w1.k + d2) w2.k in
    let l = min (w1.l + d2) w2.l in
    assert (l - k = j - i);
    let s =
      if w2.j <= w1.s then w2.s (* w2 sign extension *)
      else w1.s + d2 (* w1 sign extension, w2 extends a zero *)
    in
    let t = eval w2 w1.t in
    if i < j then
      Window (make ~i ~j ~k ~l ~s ~t)
    else begin
      (* No overlap between windows, so probably constant *)
      assert (w2.j <= w1.k || w1.l <= w2.i);
      if w2.i < w1.s && w1.k < w2.j then
        (* Tricky case: Even though windows don't overlap, this isn't constant:
           it depends on some sign-extended bits from w1, so we extract just the sign bit *)
        Window (make ~i:(j-1) ~j ~k ~l:(k+1) ~s ~t)
      else
        Const t
    end

  let eval' w n =
    match w with
    | Window w -> eval w n
    | Const k -> k

  let to_string' = function
    | Window w -> to_string w
    | Const k -> Printf.sprintf "K %nx" k

  let compose' a b =
    match a, b with
    | Window a, Window b -> compose a b
    | _, Const b -> Const b
    | Const a, Window b -> Const (eval b a)

  let id =
    make ~i:0 ~j:width ~k:0 ~l:width ~s:width ~t:0n

  let shift_left n =
    make ~i:0 ~j:(width - n) ~k:n ~l:width ~s:width ~t:0n

  let shift_right_logical n =
    make ~i:n ~j:width ~k:0 ~l:(width-n) ~s:(width-n) ~t:0n

  let shift_right_arith ~width n =
    make ~i:n ~j:width ~k:0 ~l:(width-n) ~s:width ~t:0n

  let zext n =
    make ~i:0 ~j:n ~k:0 ~l:n ~s:n ~t:0n

  let sext n =
    make ~i:0 ~j:n ~k:0 ~l:n ~s:width ~t:0n

  let mask ~lo ~hi =
    make ~i:lo ~j:hi ~k:lo ~l:hi ~s:hi ~t:0n
end

(* Random testing of the window composition function *)
module Test_Window = struct
  open Window

  let random () =
    (* random number a <= x <= b *)
    let rand a b =
      assert (a <= b);
      a + Random.int (b - a + 1)
    in
    let i = rand 0 (width - 1) in
    let j = rand (i+1) width in
    let k = rand 0 (width - (j-i)) in
    let l = k + (j - i) in
    let s = rand l width in
    let t = if k = 0 then 0n else Nativeint.shift_right_logical (Random.nativebits ()) (width - k) in
    (* Printf.printf "%d %d %d %d %d %nd\n" i j k l s t; *)
    make ~i ~j ~k ~l ~s ~t

  let test () =
    Random.self_init ();
    let succ = ref 0 in
    let iters = 1_000_000 in
    for _i = 1 to iters do
      try
        let w1 = random () and w2 = random () in
        let w = compose w1 w2 in
        for _j = 1 to 100 do
          let n = Random.nativebits () in
          let p = eval w2 (eval w1 n) in
          let q = eval' w n in
          if p <> q then
            Printf.printf "\nFailure:\n%nx\n%s -> %nx\n%s -> %nx\n%s -> %nx\n" n
              (to_string w1) (eval w1 n) (to_string w2) p (to_string' w) q;
        done;
        incr succ
      with Exit -> ()
    done;
    Printf.printf "%d/%d compose eval tests passed\n%!" !succ iters;
    succ := 0;
    for _i = 1 to iters do
      let a = random () in
      let b = random () in
      let c = random () in
      let a' = Window a and b' = Window b and c' = Window c in
      assert (compose' a' (compose' b' c') = compose' (compose' a' b') c');
      assert (compose a id = a');
      assert (compose id a = a');
      incr succ;
    done;
    Printf.printf "%d/%d compose monoidal tests passed\n%!" !succ iters;
end

(** Optimal code generation for x86_64 (see cost model & supported instructions) *)
module X86_64_Codegen = struct
  type subreg = R64 | R32

  let subreg_len = function R64 -> 64 | R32 -> 32

  type op =
    | Shl of subreg * int
    | Shr of subreg * int
    | Sar of subreg * int
    | Movzx_32_64
    | Movsx_32_64
    | Mask of {lo: int; hi:int} (* AND with immediate *)

  type x86reg = { reg64: string; reg32: string }

  let reg_name r = function
    | R64 -> r.reg64
    | R32 -> r.reg32

  let print_instruction ~reg =
    let nativeint_mask bits =
      if bits = Window.width then Nativeint.minus_one
      else Nativeint.(pred (shift_left one bits)) in
    let open Printf in function
    | Shl (w, k) -> sprintf "shl %s, %d" (reg_name reg w) k
    | Shr (w, k) -> sprintf "shr %s, %d" (reg_name reg w) k
    | Sar (w, k) -> sprintf "sar %s, %d" (reg_name reg w) k
    | Movzx_32_64 -> sprintf "mov %s, %s" (reg_name reg R32) (reg_name reg R32)
    | Movsx_32_64 -> sprintf "movsx %s, %s" (reg_name reg R64) (reg_name reg R32)
    | Mask {lo; hi} ->
      (* FIXME out-of-range mask arguments break this *)
      sprintf "and %s, %nxh" (reg_name reg R64) Nativeint.(logxor (nativeint_mask hi) (nativeint_mask lo))

  let rax = { reg64 = "rax"; reg32 = "eax" }
  let rdi = { reg64 = "rdi"; reg32 = "edi" }

  let print_code ~reg code =
    code |> List.iter (fun insn -> Printf.printf " %s\n%!" (print_instruction ~reg insn))

  let cost ops =
    let cost = function
      | Shr _ | Shl _ | Sar _ | Movsx_32_64 -> 10
      | Movzx_32_64 -> 9
      | Mask {hi;lo} when hi <= 32 || hi = 64 && lo < 32 -> 11
      | Mask _ -> 15
    in
    List.fold_left (fun c op -> c + cost op) 0 ops

  let cheapest f =
    let best_cost = ref max_int and best_code = ref [] in
    f (fun code ->
      let cost = cost code in
      if cost < !best_cost then (best_cost := cost; best_code := code));
    !best_code

  (* Equivalent to code @ [Mask {hi=l; lo=k}], but tries to avoid
     inserting the mask if it is implicitly already masked by [code] *)
  let mask code ~k ~l =
    let implicit_mask (hi, lo) op =
      let rlen = subreg_len in
      match op with
      | Shr (r, n) -> max 0 (min hi (rlen r) - n), max 0 (lo - n)
      | Sar (r, n) -> max 0 (min hi (rlen r)), max 0 (lo - n)
      | Shl (r, n) -> min (rlen r) (hi + n), min (rlen r) (lo + n)
      | _ -> 64, 0
    in
    let mask ~hi ~lo ~k ~l =
      if l >= hi && lo >= k then []
      else if l = 32 && k = 0 then [Movzx_32_64]
      else [Mask {hi=l; lo=k}]
    in
    let rec insert_mask ~hi ~lo = function
      | [] -> mask ~hi ~lo ~k ~l
      | [Shl (r, n)] as ops ->
        (* Cheaper to mask before shifting left: mask is smaller that way *)
        mask ~hi:(min hi (subreg_len r - n)) ~lo ~k:(k-n) ~l:(l-n) @ ops
      | op :: ops ->
        let hi, lo = implicit_mask (hi, lo) op in
        op :: insert_mask ~hi ~lo ops
    in
    insert_mask ~hi:64 ~lo:0 code

  (* Code for windows with no sign extension (s = l) *)
  let shift_mask code ~i ~j ~k ~l =
    ignore (Window.make ~i ~j ~k ~l ~s:l ~t:0n);
    if l = j then mask code ~k ~l
    else cheapest @@ fun cand ->
      (* Shift and mask *)
      if l < j && j <= 32 then
        cand (code @ [Shr (R32, j-l)] |> mask ~k ~l);
      if l < j then
        cand (code @ [Shr (R64, j-l)] |> mask ~k ~l);
      if l > j && l <= 32 then
        cand (code @ [Shl (R32, l-j)] |> mask ~k ~l);
      if l > j then
        cand (code @ [Shl (R64, l-j)] |> mask ~k ~l);
      (* Two shifts. Sometimes better than shift & mask, if the mask is awkwardly large *)
      if (i = 0 || k = 0) && j <= 32 && l <= 32 then
        cand (code @ [Shl (R32, 32-j); Shr(R32, 32-l)]);
      if (i = 0 || k = 0) then
        cand (code @ [Shl (R64, 64-j); Shr(R64, 64-l)]);
      if (j = 32 || l = 64) && j <= 32 then
        cand (code @ [Shr (R32, i); Shl(R64, k)]);
      if (j = 64 || l = 32) && l <= 32 then
        cand (code @ [Shr (R64, i); Shl(R32, k)]);
      if (j = 32 || l = 32) && j <= 32 && l <= 32 then
        cand (code @ [Shr (R32, i); Shl(R32, k)]);
      if (j = 64 || l = 64) then
        cand (code @ [Shr (R64, i); Shl(R64, k)])

  (* Code for windows with sign extension, and the sign bit in position 64 *)
  let window_sext64 code ~i ~k ~l ~s =
    ignore (Window.make ~i ~j:64 ~k ~l ~s ~t:0n);
    let sx = s - l in
    cheapest @@ fun cand ->
      if i >= sx then
        cand (code @ [Sar (R64, i)] |> shift_mask ~i:0 ~j:(s-k) ~k ~l:s);
      cand (code @ [Sar (R64, sx)] |> shift_mask ~i:(i - sx) ~j:64 ~k ~l:s);
      cand (code @ [Sar (R64, 64-l)] |> mask ~k ~l:s)

  (* Code for windows with sign extension, and the sign bit in position 32 *)
  let window_sext32 code ~i ~k ~l ~s =
    ignore (Window.make ~i ~j:32 ~k ~l ~s ~t:0n);
    let sx = s - l in
    cheapest @@ fun cand ->
      if i = 31 && k >= 31 then
        cand (code @ [Movsx_32_64] |> mask ~k ~l:s);
      if sx <= 32 then
        cand (code @ [Movsx_32_64] |> shift_mask ~i ~j:(32+sx) ~k ~l:s);
      if s - k <= 32 && i >= sx then
        cand (code @ [Sar(R32, i)] |> shift_mask ~i:0 ~j:(s-k) ~k ~l:s);
      if s - k <= 32 then
        cand (code @ [Sar(R32, sx)] |> shift_mask ~i:(i-sx) ~j:32 ~k ~l:s);
      if s <= 32 then
        cand (code @ [Sar(R32, 32-l)] |> mask ~k ~l:s);
      cand (code @ [Shl (R64, 32)] |> window_sext64 ~i:(i+32) ~k ~l ~s)

  (* Code for general windows with T = 0 *)
  let window_t0 code ~i ~j ~k ~l ~s =
    ignore (Window.make ~i ~j ~k ~l ~s ~t:0n);
    if s = l then shift_mask code ~i ~j ~k ~l
    else cheapest @@ fun cand ->
      (* Sign extension happens in positions 32 and 64, so move bit j to one of those *)
      if j = 32 then
        cand (code |> window_sext32 ~i ~k ~l ~s);
      if j < 32 then
        cand (code @ [Shl (R32, 32-j)] |> window_sext32 ~i:(i+(32-j)) ~k ~l ~s);
      if j > 32 && j - i <= 32 then
        cand (code @ [Shr (R64, j-32)] |> window_sext32 ~i:(i-(j-32)) ~k ~l ~s);
      if j = 64 then
        cand (code |> window_sext64 ~i ~k ~l ~s);
      if j < 64 then
        cand (code @ [Shl (R64, 64-j)] |> window_sext64 ~i:(i+(64-j)) ~k ~l ~s)

  let codegen ({i;j;k;l;s;t} : Window.t) =
    if t <> 0n then failwith "Nonzero w.t not supported yet";
    window_t0 [] ~i ~j ~k ~l ~s
end

(** Random testing of correctness and exhaustive testing of optimality *)
module Test_X86_64_Codegen = struct
  open X86_64_Codegen

  let window_for_insn op =
    let open Window in
    let op32 w =
      let (>>) = Window.compose' in
      match Window (zext 32) >> Window w >> Window (zext 32) with
      | Window w -> w
      | Const _ -> assert false
    in
    match op with
    | Shl (R64, k) -> shift_left k
    | Shr (R64, k) -> shift_right_logical k
    | Sar (R64, k) -> shift_right_arith ~width:64 k
    | Shl (R32, k) -> op32 (shift_left k)
    | Shr (R32, k) -> op32 (shift_right_logical k)
    | Sar (R32, k) -> op32 (shift_right_arith ~width:32 k)
    | Movzx_32_64 -> op32 id
    | Movsx_32_64 -> sext 32
    | Mask {lo; hi} -> mask ~lo ~hi

  let decompile code =
    let open Window in
    List.fold_left
      (fun w op -> compose' w (Window (window_for_insn op)))
      (Window id)
      code

  (** Validate a code sequence for [w] by decompiling back to a window *)
  let validate w code =
    let open Window in
    let w' = decompile code in
    if w' = Window w then true
    else begin
      Printf.printf "Miscompilation of %s as %s:\n" (to_string w) (to_string' w');
      print_code ~reg:rax code;
      false
    end

  let check_codegen () =
    let succ = ref 0 in
    let iters = 1_000_000 in
    for i = 1 to iters do
      let w = Test_Window.random () in
      let w = { w with t = 0n } in
      let code = codegen w in
      if validate w code then incr succ
    done;
    Printf.printf "%d/%d codegen tests passed\n%!" !succ iters

  let sz_64 = List.init 64 Fun.id
  let sz_32 = List.init 32 Fun.id

  let simple_ops =
    List.map (fun k -> Shl (R64, k)) sz_64 @
    List.map (fun k -> Shr (R64, k)) sz_64 @
    List.map (fun k -> Sar (R64, k)) sz_64 @
    List.map (fun k -> Shl (R32, k)) sz_32 @
    List.map (fun k -> Shr (R32, k)) sz_32 @
    List.map (fun k -> Sar (R32, k)) sz_32 @
    [Movzx_32_64;
     Movsx_32_64]

  let mask_ops =
    List.concat_map (fun lo -> List.concat_map (fun hi -> if lo < hi then [Mask {lo;hi}] else []) sz_64) sz_64

  let check_optimality () =
    (* The worst case is two shifts and a mask, so to check optimality
       we need to check everything cheaper than that *)
    let check code =
      match decompile code with
      | Const _ -> true
      | Window w ->
        assert (w.t = 0n);
        let code' = codegen w in
        assert (validate w code');
        if cost code' <= cost code then true
        else begin
          Printf.printf "Nonoptimal code for %s:\n" (Window.to_string w);
          print_code ~reg:rax code;
          Printf.printf "compiled as:\n";
          print_code ~reg:rax code';
          false;
        end
    in
    let opt = ref 0 and tried = ref 0 in
    (simple_ops @ mask_ops) |> List.iter (fun op ->
      incr tried; if check [op] then incr opt);
    Printf.printf "%d/%d single ops nonoptimal\n%!" (!tried - !opt) !tried;
    opt := 0; tried := 0;
    (simple_ops @ mask_ops) |> List.iter (fun op1 ->
      (simple_ops @ mask_ops) |> List.iter (fun op2 ->
        incr tried; if check [op1;op2] then incr opt));
    Printf.printf "%d/%d double ops nonoptimal\n%!" (!tried - !opt) !tried;
    opt := 0; tried := 0;
    simple_ops |> List.iter (fun op1 ->
      simple_ops |> List.iter (fun op2 ->
        simple_ops |> List.iter (fun op3 ->
          incr tried; if check [op1;op2;op3] then incr opt)));
    Printf.printf "%d/%d triple ops nonoptimal\n%!" (!tried - !opt) !tried;
    ()
end

module Expr = struct
  type term =
    | Var of string
    | Const of Nativeint.t
    | Window of term * Window.t
    | Bitwise of [`And|`Or] * term * term

  let rec cost = function
    | Var _ | Const _ -> 0
    | Window _ -> 1
    | Bitwise (_, a, b) -> 1 + cost a + cost b

  let bitwise op a b =
    match op, a, b with
    | `Or, Const 0n, t -> t
    | `Or, t, Const 0n -> t
    | op, a, b -> Bitwise (op, a, b)

  let rec window w = function
    | Var _ as t -> Window (t, w)
    | Const k -> Const (Window.eval w k)
    | Window (t, w') ->
      (match Window.compose w' w with
       | Window w -> window w t
       | Const k -> Const k)
    | Bitwise (op, a, b) as orig ->
      let temp = bitwise op (window w a) (window w b) in
      if cost temp < cost orig + 1 then temp else Window (orig, w)
end

(* Example *)
let process_pixel t =
  let open Expr in
  let open Window in
  t
  |> window (make ~i:5 ~j:11 ~k:0 ~l:6 ~s:width ~t:0n)
  |> window (shift_left 2)
  |> window (sext 16)
  |> window (zext 32)

let pack_rgb r g b =
  let open Expr in
  let open Window in
  let r = window (zext 5) r in
  let g = window (zext 6) g in
  let b = window (zext 5) b in
  bitwise `Or (window (shift_left 11) r)
    (bitwise `Or (window (shift_left 5) g) b)

let () =
  Random.self_init ();
  Test_Window.test ();
  Test_X86_64_Codegen.check_codegen ();
  Test_X86_64_Codegen.check_optimality ();
  Printf.printf "process_pixel(pixel):\n";
  begin match process_pixel (Var "pixel") with
  | Window (Var "pixel", w) ->
    X86_64_Codegen.(print_code ~reg:rdi (codegen w))
  | _ -> assert false
  end;
  Printf.printf "process_grb(g, r, b):\n";
  begin match process_pixel (pack_rgb (Var "r") (Var "g") (Var "b")) with
  | Window (Var "g", w) ->
    X86_64_Codegen.(print_code ~reg:rdi (codegen w))
  | _ -> assert false
  end

1000000/1000000 compose eval tests passed
1000000/1000000 compose monoidal tests passed
1000000/1000000 codegen tests passed
0/2306 single ops nonoptimal
0/5317636 double ops nonoptimal
0/24389000 triple ops nonoptimal
process_pixel(pixel):
shl edi, 21
sar edi, 26
shl edi, 2
process_grb(g, r, b):
shl edi, 26
sar edi, 24