
@skochinsky
Forked from pervognsen/shift_dfa.md
Created August 2, 2021 12:41
Shift-based DFAs

A traditional table-based DFA implementation looks like this:

```c
#include <stdint.h>

uint8_t table[NUM_STATES][256];

uint8_t run(const uint8_t *start, const uint8_t *end, uint8_t state) {
    for (const uint8_t *s = start; s != end; s++)
        state = table[state][*s];
    return state;
}
```

On Skylake, the throughput should be around 7 cycles per input byte (if the table fits in L1 cache). The main problem is that you can't issue the load to compute the next state until you've finished computing the current state, even on an out-of-order machine. There's actually a smaller, hidden problem too: when you compile the implicit address calculation involved in the table lookup you get something like this:

```
state = load(table + 256 * state + *s)
```

As a result, the multiplication by 256 (shift by 8) is also added to the loop-carried dependency chain for the state, resulting in an extra cycle of latency. You can fix that by just flipping the table:

```c
uint8_t table[256][NUM_STATES];

uint8_t run(const uint8_t *start, const uint8_t *end, uint8_t state) {
    for (const uint8_t *s = start; s != end; s++)
        state = table[*s][state];
    return state;
}
```

Now the compiled table lookup looks like this:

```
state = load(table + NUM_STATES * *s + state)
```

That apparently trivial change should yield 16% better throughput (6 rather than 7 cycles per byte). Incidentally, it can also help with memory locality if only a subset of the input bytes is seen in practice (e.g. 0..127 for ASCII). Even though that isn't the topic of this write-up, I wanted to go through that example since (1) using the "wrong" table layout is a common mistake (and I make it all the time despite being on guard) and (2) it illustrates how the loop-carried dependency chain latency for the state is what controls the performance of a DFA loop like this.
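To make the flip concrete, here is a minimal sketch that copies a state-major table into the byte-major layout and runs both loops side by side. The names (`table_by_state`, `table_by_byte`, `flip_table`, `run_by_state`, `run_by_byte`) and `NUM_STATES = 4` are hypothetical, chosen only for this illustration; the transitions themselves are identical, only the memory order differs.

```c
#include <assert.h>
#include <stdint.h>

#define NUM_STATES 4  /* hypothetical size for this sketch */

/* State-major layout, as in the first listing: [state][byte]. */
uint8_t table_by_state[NUM_STATES][256];
/* Byte-major (flipped) layout: [byte][state]. */
uint8_t table_by_byte[256][NUM_STATES];

/* Copy the same transitions into the flipped layout. */
void flip_table(void) {
    for (int st = 0; st < NUM_STATES; st++)
        for (int c = 0; c < 256; c++) {
            assert(table_by_state[st][c] < NUM_STATES);  /* valid target */
            table_by_byte[c][st] = table_by_state[st][c];
        }
}

/* Address is base + 256 * state + byte: the scaling sits on the state chain. */
uint8_t run_by_state(const uint8_t *start, const uint8_t *end, uint8_t state) {
    for (const uint8_t *s = start; s != end; s++)
        state = table_by_state[state][*s];
    return state;
}

/* Address is base + NUM_STATES * byte + state: the scaled part depends
   only on the byte, so it can be computed off the state chain. */
uint8_t run_by_byte(const uint8_t *start, const uint8_t *end, uint8_t state) {
    for (const uint8_t *s = start; s != end; s++)
        state = table_by_byte[*s][state];
    return state;
}
```

For example, filling `table_by_state` with a DFA that counts `'a'` bytes mod 4 and calling `flip_table()` makes both loops return 3 for `"banana"` from state 0.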

In order to make this go faster, let's think of the state transition table as an array of rows where indexing with an input byte yields the row that specifies all possible state transitions. Instead of doing our table lookup as a single-level load (with an address calculation that combines the byte and current state) we're going to do a two-level lookup where the first-level lookup is a load from memory that only depends on the input byte and gives us the row, and the second-level lookup picks out the right column within the row based on the current state. The key to making this faster is that the second-level lookup should not involve memory latency: it has to be done as a fast ALU operation, ideally in one cycle. This suggests the following approach: each row will be encoded as a 64-bit word, and we pick out the state-dependent column within the row by shifting and masking:

```c
uint64_t table[256];

uint8_t run(const uint8_t *start, const uint8_t *end, uint8_t state) {
    for (const uint8_t *s = start; s != end; s++) {
        uint64_t row = table[*s];
        state = (row >> (state * BITS_PER_STATE)) & ((1 << BITS_PER_STATE) - 1);
    }
    return state;
}
```
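The loop above assumes `table` already holds packed rows. Here is a minimal sketch of how such rows could be built from a conventional next-state table `next[state][byte]`, using sequential state numbers and `BITS_PER_STATE = 6` as an example value. `pack_rows`, `run_shift` and `MAX_STATES` are hypothetical names for this illustration.

```c
#include <assert.h>
#include <stdint.h>

#define BITS_PER_STATE 6                   /* example choice for this sketch */
#define MAX_STATES (64 / BITS_PER_STATE)   /* 10 six-bit fields fit in a word */

/* Pack next[state][byte] into one 64-bit row per input byte. Field `st` of
   rows[c] occupies bits [st * BITS_PER_STATE, (st + 1) * BITS_PER_STATE)
   and holds the plain (sequential) index of the next state. */
void pack_rows(uint64_t rows[256], uint8_t next[][256], int num_states) {
    assert(num_states >= 1 && num_states <= MAX_STATES);
    for (int c = 0; c < 256; c++) {
        uint64_t row = 0;
        for (int st = 0; st < num_states; st++) {
            assert(next[st][c] < num_states);
            row |= (uint64_t)next[st][c] << (st * BITS_PER_STATE);
        }
        rows[c] = row;
    }
}

/* The same loop as the listing above, with the table passed in explicitly. */
uint8_t run_shift(const uint64_t rows[256], const uint8_t *start,
                  const uint8_t *end, uint8_t state) {
    for (const uint8_t *s = start; s != end; s++) {
        uint64_t row = rows[*s];  /* load depends only on the input byte */
        state = (uint8_t)((row >> (state * BITS_PER_STATE)) &
                          ((1u << BITS_PER_STATE) - 1));
    }
    return state;
}
```

The row load no longer involves the state at all; only the shift, multiply and mask remain on the state's dependency chain.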

The shift has 1c latency, the multiply has 3c latency (or 1c if BITS_PER_STATE is a power of two, since it then becomes a shift), and the masking has 1c latency. So the latency on the critical path is only 5c, or 3c with a power-of-two BITS_PER_STATE, which is already much better because the row load depends only on the input byte and is off the critical path. But now for the final two tricks. First, instead of encoding the states sequentially as 0, 1, 2, etc., we are going to pre-multiply them by BITS_PER_STATE:

```c
uint64_t table[256];

uint8_t run(const uint8_t *start, const uint8_t *end, uint8_t state) {
    for (const uint8_t *s = start; s != end; s++) {
        uint64_t row = table[*s];
        state = (row >> state) & ((1 << BITS_PER_STATE) - 1);
    }
    return state;
}
```

And finally we pick BITS_PER_STATE = 6 so that (1 << BITS_PER_STATE) - 1 = 63. Most instruction sets already interpret 64-bit shift amounts mod 64, so the masking is now automatic at the machine code level:

```c
uint64_t table[256];

uint64_t run(const uint8_t *start, const uint8_t *end, uint64_t state) {
    for (const uint8_t *s = start; s != end; s++) {
        uint64_t row = table[*s];
        state = row >> (state & 63);
    }
    return state & 63;
}
```

I also propagated the & 63 to the shift operand to help the compiler's instruction selection do what we want. The compiler now generates this for what was the loop's critical path:

```asm
shrx rax, qword ptr [8*rcx + table], rax
```

That is, 1c latency on the critical path for the state in register rax.
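Putting the pieces together, here is a self-contained sketch of the final premultiplied 6-bit version, including a builder that packs a conventional next-state table into rows. `pack_rows_premul` and `run_premul` are hypothetical names for this illustration; callers pass `6 * k` to start in state `k` and divide the result by 6 to get the state index back.

```c
#include <assert.h>
#include <stdint.h>

/* Premultiplied encoding: state k is represented by the shift amount 6 * k,
   and field k of a row (bits [6k, 6k + 6)) stores 6 * next_state. */
void pack_rows_premul(uint64_t rows[256], uint8_t next[][256], int num_states) {
    assert(num_states >= 1 && num_states <= 10);
    for (int c = 0; c < 256; c++) {
        uint64_t row = 0;
        for (int k = 0; k < num_states; k++) {
            assert(next[k][c] < num_states);
            row |= (uint64_t)(6 * next[k][c]) << (6 * k);
        }
        rows[c] = row;
    }
}

/* Bits above the low 6 are leftovers from the row; the & 63 discards them
   and can typically be folded into the shift instruction itself. */
uint64_t run_premul(const uint64_t rows[256], const uint8_t *start,
                    const uint8_t *end, uint64_t state) {
    for (const uint8_t *s = start; s != end; s++)
        state = rows[*s] >> (state & 63);
    return state & 63;  /* still premultiplied: divide by 6 for the index */
}
```

Note that the state register never needs an explicit mask inside the loop: whatever garbage the shift leaves above bit 5 is ignored by the next shift count.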

Now that the latency is so low, we should also examine other potential bottlenecks that might prevent us from reaching 1 byte/cycle. We'd need 2 scalar loads per cycle, one for the input byte and one for the row. That happens to be the limit for Skylake, which is a good stand-in for any consumer PC sold in the last 6-7 years; newer AMD processors (starting with Zen 3) and Apple's M1 processor can do 3 scalar loads per cycle. So we should just scrape by with enough load capacity, even on Skylake. Another bottleneck to consider is that the string pointer increment imposes a 1 cycle/byte latency limit. Also, modern x86 CPUs can sustain at most one taken branch per cycle, and some only one taken branch every two cycles. Finally, for a 4-wide issue processor like Skylake, there might be too many instructions per input byte. So four other resources (load throughput, the pointer increment, taken branches and issue width) are just on the cusp of being too tight, which usually spells doom for sustained pipeline utilization.

Fortunately the last three bottlenecks are alleviated by just unrolling the loop. Clang will unroll aggressively on its own, but GCC won't unroll unless forced, so you'll want to do the unrolling manually to get reliable performance across compilers. Because we need near-perfect pipeline utilization of critical resources (2 loads per cycle, etc.), I kept seeing marginal gains even up to absurdly high unroll factors; just keep in mind that this was in a micro-benchmark, and moderate unrolling is better suited for real-world use.
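A minimal sketch of manual unrolling applied to the premultiplied loop, so that performance doesn't hinge on each compiler's unrolling heuristics. The unroll factor of 4 and the name `run_unrolled` are arbitrary choices for illustration; since gains were observed at much higher factors in a micro-benchmark, treat 4 as a starting point to tune rather than a recommendation.

```c
#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* rows[] is assumed to hold premultiplied 6-bit rows as described above. */
uint64_t run_unrolled(const uint64_t rows[256], const uint8_t *start,
                      const uint8_t *end, uint64_t state) {
    assert(start <= end);
    const uint8_t *s = start;
    size_t n = (size_t)(end - start);
    /* Main body: 4 bytes per iteration, so one pointer increment and one
       taken branch are amortized over 4 state updates. */
    for (; n >= 4; n -= 4, s += 4) {
        state = rows[s[0]] >> (state & 63);
        state = rows[s[1]] >> (state & 63);
        state = rows[s[2]] >> (state & 63);
        state = rows[s[3]] >> (state & 63);
    }
    /* Tail: the remaining 0-3 bytes. */
    for (; s != end; s++)
        state = rows[*s] >> (state & 63);
    return state & 63;
}
```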

The main restriction with a shift-based DFA is that each row is 64 bits and hence with BITS_PER_STATE = 6 you can only accommodate up to 10 states. But 10 states is enough for a lot of useful tasks, including UTF-8 validation, simple needle-in-a-haystack pattern matching, skip lexing, etc. For many of these applications it's helpful to use an absorbing "state of interest" (e.g. an absorbing error state for UTF-8 validation) so you can hoist state-based branches out of the innermost loop.
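As a sketch of the needle-in-a-haystack case with an absorbing state of interest: states 0..m count how many bytes of an m-byte needle are currently matched (built with the standard KMP automaton construction), and the final "found" state m is made absorbing, so the only state-based branch is a check after each chunk. `build_needle_rows`, `contains` and the 4096-byte chunk size are hypothetical choices for this illustration; the needle is limited to 9 bytes so that its m + 1 states fit in 10 six-bit fields.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Build premultiplied 6-bit rows for a substring-search DFA. */
void build_needle_rows(uint64_t rows[256], const char *needle) {
    size_t m = strlen(needle);
    assert(m >= 1 && m <= 9);
    uint8_t next[10][256];
    for (int c = 0; c < 256; c++)
        next[0][c] = (uint8_t)(c == (unsigned char)needle[0]);
    size_t x = 0;  /* restart state: where needle[1..st) leads from state 0 */
    for (size_t st = 1; st < m; st++) {
        unsigned char want = (unsigned char)needle[st];
        for (int c = 0; c < 256; c++)
            next[st][c] = next[x][c];      /* mismatch: behave like state x */
        next[st][want] = (uint8_t)(st + 1);  /* match: advance */
        x = next[x][want];
    }
    for (int c = 0; c < 256; c++)
        next[m][c] = (uint8_t)m;           /* absorbing "found" state */
    for (int c = 0; c < 256; c++) {
        uint64_t row = 0;
        for (size_t st = 0; st <= m; st++)
            row |= (uint64_t)(6 * next[st][c]) << (6 * st);
        rows[c] = row;
    }
}

/* m is the needle length; state m is absorbing, so checking it once per
   chunk is enough and the inner loop stays branch-free. */
bool contains(const uint64_t rows[256], size_t m, const uint8_t *data, size_t len) {
    const uint64_t found = 6 * (uint64_t)m;
    uint64_t state = 0;
    size_t i = 0;
    while (i < len) {
        size_t chunk_end = (len - i > 4096) ? i + 4096 : len;
        for (; i < chunk_end; i++)
            state = rows[data[i]] >> (state & 63);
        if ((state & 63) == found)
            return true;
    }
    return false;
}
```

For instance, with the rows built for `"abab"`, `contains` finds the needle in `"xxabacababyy"` but not in `"abaabbabba"`.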

Anyway, that's how you make a table-driven DFA go fast with simple, portable code. On my laptop's Ryzen 5900HX (Zen 3) processor, this runs at a throughput of 1 byte/cycle as predicted (4.5-4.6 GB/s with 4.6 GHz max clock), 4-5 times faster than a traditional table-based DFA. Not too bad for executing an arbitrary 10-state DFA!

Note that you can scale this idea beyond 10 states: make a row 2x64 bits, load both words, and mux between them. BITS_PER_STATE is then 7, where the lower 6 bits are the shift amount and the 7th bit is the word selector, so you can fit 9 states per word and 18 states per row. But you leave the 1 byte/cycle sweet spot as soon as you add anything to the state-dependent logic (and the other bottlenecks will also assert themselves if you add much of anything). Concretely, just the TEST + CMOV you need for 2x64-bit rows adds 2 cycles to the critical path latency, so throughput is one third of a pure shift-based DFA. That said, the performance gap to a traditional table-based DFA is large enough that you can fit a few loads and muxes like this and still come out ahead, so your takeaway shouldn't be that this general approach is totally useless beyond 10 states. On Skylake, where an L1 load with a complex addressing mode is 6 cycles and [rax + rbx] is considered complex, 3 cycles is still 2x faster. But on Zen 3, where an L1 load with a simple addressing mode is only 4 cycles and [rax + rbx] is considered simple, 3 cycles is not very compelling given the limitations.
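Here is one way the 2x64-bit variant could look, as a sketch under the encoding just described: bit 6 of the state selects the word and bits 0-5 are the shift amount (7 times the slot within the word). The helper names and the interleaved `rows[2*c]`, `rows[2*c + 1]` layout are assumptions for illustration, and whether the ternary actually compiles to TEST + CMOV is up to the compiler.

```c
#include <assert.h>
#include <stdint.h>

/* State k lives in word k / 9, slot k % 9: code = 64 * word + 7 * slot. */
uint64_t encode_state(int k) {
    return (uint64_t)(k / 9) * 64 + (uint64_t)(k % 9) * 7;
}

int decode_state(uint64_t code) {
    return (int)((code >> 6) & 1) * 9 + (int)(code & 63) / 7;
}

/* rows holds two words per input byte: rows[2*c] and rows[2*c + 1]. */
void pack_rows_2x64(uint64_t rows[512], uint8_t next[][256], int num_states) {
    assert(num_states >= 1 && num_states <= 18);
    for (int c = 0; c < 256; c++) {
        rows[2 * c] = rows[2 * c + 1] = 0;
        for (int k = 0; k < num_states; k++) {
            assert(next[k][c] < num_states);
            rows[2 * c + k / 9] |= encode_state(next[k][c]) << (7 * (k % 9));
        }
    }
}

uint64_t run_2x64(const uint64_t *rows, const uint8_t *start,
                  const uint8_t *end, uint64_t state) {
    for (const uint8_t *s = start; s != end; s++) {
        uint64_t lo = rows[2 * *s], hi = rows[2 * *s + 1];  /* byte-only loads */
        uint64_t word = (state & 64) ? hi : lo;  /* the mux: ideally TEST + CMOV */
        state = word >> (state & 63);  /* bits above 6 are ignored next time */
    }
    return state & 127;
}
```

Because the next iteration only looks at `state & 64` and `state & 63`, no explicit `& 127` is needed inside the loop, which keeps the critical path at the shift plus the mux.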

Background: To the best of my knowledge the single-cycle shift-based DFA is original work [1], but presented in this manner I hope it seems obvious to you in hindsight. The idea of separating the input-dependent and state-dependent latencies in the state update is known to SIMD hackers (e.g. Geoff Langdale used it in Hyperscan), and that was the main inspiration: you load the row as a vector and then use PSHUFB, VTBL and similar instructions to do the column selection. The 1 byte/cycle version of a PSHUFB-based DFA is limited to 16 states. Since PSHUFB/VTBL is vectorized it also allows for other tricks based on parallel-prefix circuits you can't do with the shift-based scalar technique, but you run into the same max loads/cycle wall. The input bytes can be loaded in groups of 8 using a single load and then bitwise extracted from there. The extra overhead of the bitwise extraction was too high on x86 to be a net win when I tried it, but we got to 2.3 bytes/cycle with a parallel-prefix VTBL-based DFA (my initial version was 1.5 bytes/cycle and @dougallj and I eventually pushed it to 2.3) on an M1 Firestorm perf core in the new Macs (and since VTBL is high latency compared to PSHUFB you need the latency hiding from parallel-prefix windowing just to reach 1 byte/cycle): https://twitter.com/pervognsen/status/1365170848215142400

[1] I first tweeted about it here: https://twitter.com/pervognsen/status/1364164150843252736. If you're interested, there's a discussion in the tweet thread where some of the initial refinement happened and a bunch of extensions were explored. Travis Downs pointed out that you can do 16 states using BITS_PER_STATE = 4 if you're okay with 2 cycles/byte: https://twitter.com/trav_downs/status/1366524409196912641. This is strictly better than the 3 cycles/byte 2-to-1 row mux except in the edge case where you need 17 or 18 states and 16 won't do. And at 2 cycles/byte it's a big win even on machines with lower L1 latency like Zen 3. Another sub-thread discusses the fact that with the single-cycle version you can support a restricted 11th state. You only have 4 bits left after 10 * 6 bits for the 10 states, so the 11th state's field can only hold values 0-15, and the only valid premultiplied states in that range are 0, 6 and 12 (assuming you're using a logical right shift). Its outgoing edges can therefore only go to those three states. Since you have the freedom to permute the state numbers, you can make the 11th state be any state whose outgoing edges go to at most 3 other states. And in some special cases you can squeeze extra juice from the algorithm by using an arithmetic right shift to perform bit smearing. Go through the linked Twitter threads for more ideas/details.
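For the 16-state, 4-bits-per-state variant, the sketch below is one possible reconstruction of how 2 cycles/byte can be reached, not necessarily the formulation from the linked tweet: each 4-bit field stores a plain state index, and the running state is scaled by 4 after the shift, so the state chain is a shift plus a second shift (or LEA), with the `& 63` again left to the shift count. `pack_rows_16` and `run_16` are hypothetical names; callers pass `4 * k` to start in state `k`.

```c
#include <assert.h>
#include <stdint.h>

/* Field k of a row (bits [4k, 4k + 4)) stores the plain index of the
   next state; the running state keeps 4 * k in its low 6 bits. */
void pack_rows_16(uint64_t rows[256], uint8_t next[][256], int num_states) {
    assert(num_states >= 1 && num_states <= 16);
    for (int c = 0; c < 256; c++) {
        uint64_t row = 0;
        for (int k = 0; k < num_states; k++) {
            assert(next[k][c] < num_states);
            row |= (uint64_t)next[k][c] << (4 * k);
        }
        rows[c] = row;
    }
}

uint64_t run_16(const uint64_t *rows, const uint8_t *start,
                const uint8_t *end, uint64_t state) {
    for (const uint8_t *s = start; s != end; s++) {
        /* Shift the field down, then scale it by 4. The next iteration's
           & 63 keeps exactly (4-bit field) * 4 and drops the rest. */
        state = (rows[*s] >> (state & 63)) << 2;
    }
    return (state & 63) >> 2;  /* plain state index 0..15 */
}
```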
