Skip to content

Instantly share code, notes, and snippets.

Created May 3, 2022 11:25
Show Gist options
  • Save klauspost/617e149f31f8967bc184f5a48c3834f4 to your computer and use it in GitHub Desktop.
Save klauspost/617e149f31f8967bc184f5a48c3834f4 to your computer and use it in GitHub Desktop.
//go:build amd64 && !appengine && !noasm && gc
// +build amd64,!appengine,!noasm,gc
// This file contains the specialisation of Decoder.Decompress4X
// that uses an asm implementation of its main loop.
package huff0
import (
// decompress4x_main_loop_amd64_9, decompress4x_main_loop_amd64_10 and
// decompress4x_main_loop_amd64_11 are x86 assembler implementations
// of the Decompress4X main loop for table logs 9, 10 and 11 (tablelog > 8).
// Decompress4X stores the returned value in its off variable and uses it as
// the number of bytes to copy out of each temporary per-stream buffer.
func decompress4x_main_loop_amd64_9(ctx *decompress4xContext) uint8
func decompress4x_main_loop_amd64_10(ctx *decompress4xContext) uint8
func decompress4x_main_loop_amd64_11(ctx *decompress4xContext) uint8
// decompress4x_8b_main_loop_amd64 is an x86 assembler implementation
// of the Decompress4X main loop when tablelog <= 8, which decodes 4 entries
// per stream per loop.
// Decompress4X stores the returned value in its off variable.
func decompress4x_8b_main_loop_amd64(ctx *decompress4xContext) uint8
// fallback8BitSize is the destination capacity, in bytes, below which
// Decompress4X hands input that uses 8-bit tables to decompress4X8bit,
// because the Go version is faster at that size.
const fallback8BitSize = 800
// decompress4xContext carries the inputs of the assembler main loops;
// Decompress4X fills it in and passes its address to them.
// NOTE(review): the closing brace of this struct is not visible in this listing.
type decompress4xContext struct {
// pbr0 to pbr3 point at the bit readers of the four input streams.
pbr0 *bitReaderShifted
pbr1 *bitReaderShifted
pbr2 *bitReaderShifted
pbr3 *bitReaderShifted
// peekBits is set by Decompress4X to (64 - actualTableLog) & 63.
peekBits uint8
// buf points at the first byte of the temporary output buffers.
buf *byte
// tbl points at the first entry of the single-symbol decoding table.
tbl *dEntrySingle
// Decompress4X will decompress a 4X encoded stream.
// The length of the supplied input must match the end of a block exactly.
// The *capacity* of the dst slice must match the destination size of
// the uncompressed data exactly.
//
// src starts with a 6-byte jump table holding three little-endian 16-bit
// lengths for the first three streams; the fourth stream is the rest of src.
// The assembler main loops fill temporary per-stream buffers that are copied
// into dst, and the tail of each stream is then decoded in Go.
//
// It returns dst resliced to its full capacity, or nil and an error when no
// table is loaded, the input is too small or truncated, or corruption is
// detected.
//
// NOTE(review): closing braces and possibly other statements are not visible
// in this listing; the comments below describe only the lines that are present.
func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) {
// A decoding table must have been loaded before decompressing.
if len(d.dt.single) == 0 {
return nil, errors.New("no table loaded")
// The input needs the 6-byte jump table plus at least one byte for each of the 4 streams.
if len(src) < 6+(4*1) {
return nil, errors.New("input too small")
use8BitTables := d.actualTableLog <= 8
// Small destinations that use 8-bit tables are handed to decompress4X8bit instead.
if cap(dst) < fallback8BitSize && use8BitTables {
return d.decompress4X8bit(dst, src)
// One bit reader per input stream.
var br [4]bitReaderShifted
// Decode "jump table"
// Bytes 0-5 hold the three stream lengths; stream data starts at offset 6.
start := 6
for i := 0; i < 3; i++ {
length := int(src[i*2]) | (int(src[i*2+1]) << 8)
// Each of the first three streams must end strictly before the end of src.
if start+length >= len(src) {
return nil, errors.New("truncated input (or invalid offset)")
err := br[i].init(src[start : start+length])
if err != nil {
return nil, err
start += length
// The fourth stream is whatever remains of src.
err := br[3].init(src[start:])
if err != nil {
return nil, err
// destination, offset to match first output
dstSize := cap(dst)
dst = dst[:dstSize]
out := dst
// dstEvery is the distance in dst between the outputs of consecutive streams: dstSize/4 rounded up.
dstEvery := (dstSize + 3) / 4
const tlSize = 1 << tableLogMax
const tlMask = tlSize - 1
single := d.dt.single[:tlSize]
// Use temp table to avoid bound checks/append penalty.
buf := d.buffer()
// off receives the value returned by the assembler main loop.
var off uint8
// decoded counts the output bytes accounted for so far, summed over all four streams.
var decoded int
const debug = false
// Everything the assembler main loop reads, passed to it as a single pointer.
ctx := decompress4xContext{
pbr0: &br[0],
pbr1: &br[1],
pbr2: &br[2],
pbr3: &br[3],
peekBits: uint8((64 - d.actualTableLog) & 63), // see: bitReaderShifted.peekBitsFast()
buf: &buf[0][0],
tbl: &single[0],
// Decode 2 values from each decoder/loop.
// bufoff is the number of bytes taken from each per-stream buffer when a full buffer is flushed.
const bufoff = 256
for {
// Test whether any reader's off field has dropped below 4.
// NOTE(review): the statement executed when this holds is not visible here.
if br[0].off < 4 || br[1].off < 4 || br[2].off < 4 || br[3].off < 4 {
// Pick the assembler loop that matches the table log.
if use8BitTables {
off = decompress4x_8b_main_loop_amd64(&ctx)
} else {
switch d.actualTableLog {
case 9:
off = decompress4x_main_loop_amd64_9(&ctx)
case 10:
off = decompress4x_main_loop_amd64_10(&ctx)
case 11:
off = decompress4x_main_loop_amd64_11(&ctx)
//panic(fmt.Sprintf("unexpected tablelog: %d", d.actualTableLog))
// Optional dump of the returned offset and the state of all four readers.
if debug {
fmt.Print("DEBUG: ")
fmt.Printf("off=%d,", off)
for i := 0; i < 4; i++ {
fmt.Printf(" br[%d]={bitsRead=%d, value=%x, off=%d}",
i, br[i].bitsRead, br[i].value, br[i].off)
// NOTE(review): the body of this check is not visible here; the statements that follow flush full buffers.
if off != 0 {
// A full buffer of bufoff bytes must fit inside one stream's share of dst.
if bufoff > dstEvery {
return nil, errors.New("corruption detected: stream overrun 1")
// Copy each stream's full buffer to that stream's region of the output.
copy(out, buf[0][:])
copy(out[dstEvery:], buf[1][:])
copy(out[dstEvery*2:], buf[2][:])
copy(out[dstEvery*3:], buf[3][:])
// Advance the output window by bufoff; the per-stream offsets are relative to it.
out = out[bufoff:]
decoded += bufoff * 4
// There must at least be 3 buffers left.
if len(out) < dstEvery*3 {
return nil, errors.New("corruption detected: stream overrun 2")
// Flush a partially filled buffer: only the first off bytes of each stream are copied.
if off > 0 {
ioff := int(off)
// The fourth stream's region must have room for ioff more bytes.
if len(out) < dstEvery*3+ioff {
return nil, errors.New("corruption detected: stream overrun 3")
copy(out, buf[0][:off])
copy(out[dstEvery:], buf[1][:off])
copy(out[dstEvery*2:], buf[2][:off])
copy(out[dstEvery*3:], buf[3][:off])
decoded += int(off) * 4
out = out[off:]
// Decode remaining.
// remainBytes is what is left of one stream's dstEvery-byte share after the bytes decoded so far.
remainBytes := dstEvery - (decoded / 4)
for i := range br {
// offset and endsAt bound stream i's remaining output inside out; endsAt is clamped to len(out).
offset := dstEvery * i
endsAt := offset + remainBytes
if endsAt > len(out) {
endsAt = len(out)
// br shadows the array with a pointer to the current stream's reader.
br := &br[i]
bitsLeft := br.remaining()
for bitsLeft > 0 {
// More bits than output space means the stream is corrupt.
if offset >= endsAt {
return nil, errors.New("corruption detected: stream overrun 4")
// Read value and increment offset.
val := br.peekBitsFast(d.actualTableLog)
v := single[val&tlMask].entry
// The low byte of the entry is the bit count, the high byte is the decoded value.
nBits := uint8(v)
bitsLeft -= uint(nBits)
out[offset] = uint8(v >> 8)
// NOTE(review): no statement advancing the reader or offset is visible in this loop — confirm against the original.
// The stream must have filled its region exactly.
if offset != endsAt {
return nil, fmt.Errorf("corruption detected: short output block %d, end %d != %d", i, offset, endsAt)
decoded += offset - dstEvery*i
err = br.close()
if err != nil {
return nil, err
// The total number of decoded bytes must equal the destination size.
if dstSize != decoded {
return nil, errors.New("corruption detected: short output block")
return dst, nil
package main
//go:generate go run gen.go -out ../decompress_amd64.s -pkg=huff0
import (
_ ""
. ""
. ""
// main generates one specialised main-loop procedure for each table log from 9 to 11.
// NOTE(review): closing braces and possibly other statements are not visible
// in this listing; only the lines that are present are described.
func main() {
decompress := decompress4x{}
for i := 9; i <= 11; i++ {
decompress.nBits = i
// n is the base of the ids handed to decodeTwoValues (n+0 .. n+3); scaling by 10
// keeps the ids, and therefore the skip_fill labels, of the three procedures distinct.
decompress.n = i * 10
decompress.generateProcedure(fmt.Sprintf("decompress4x_main_loop_amd64_%d", i))
// Presumably the generator value for the 8-bit table variant; its use is not visible here.
decompress8b := decompress4x{}
// buffoff is the size in bytes of one per-stream output buffer; the generated
// stores use streamIndex*buffoff as the displacement into the buffer area.
const buffoff = 256 // see decompress.go, we're using [4][256]byte table
// decompress4x holds the parameters of one generated main-loop procedure.
// NOTE(review): the closing brace of this struct is not visible in this listing.
type decompress4x struct {
// n is the base id passed to decodeTwoValues; id%10 selects the stream's buffer.
n int
// nBits is the table log the procedure is specialised for; it fixes the
// peek shift (64-nBits) and the refill threshold (2*nBits).
nBits int
// bmi2 selects the SHLX-based instruction variants instead of CL shifts.
bmi2 bool
// generateProcedure emits the assembler function called name: the main loop
// for table log d.nBits, which decodes two values from each of the four
// streams per iteration and returns the low byte of the output offset.
// NOTE(review): closing braces and possibly other statements are not visible
// in this listing; only the lines that are present are described.
func (d decompress4x) generateProcedure(name string) {
TEXT(name, 0, "func(ctx* decompress4xContext) uint8")
// NOTE(review): the Doc text lacks a space after "tablelog > 8." — confirm and fix in the string.
Doc(name+" is an x86 assembler implementation of Decompress4X when tablelog > 8.decodes a sequence", "")
// off is the output offset; only its low byte is updated and tested, so it wraps to 0 after 256.
off := GP64()
XORQ(off, off)
// exhausted becomes non-zero once any reader is running out of input.
exhausted := GP64()
XORQ(exhausted.As64(), exhausted.As64()) // exhausted = false
buffer := GP64()
table := GP64()
br0 := GP64()
br1 := GP64()
br2 := GP64()
br3 := GP64()
Comment("Preload values")
// Load the context fields into registers once, outside the loop.
ctx := Dereference(Param("ctx"))
Load(ctx.Field("buf"), buffer)
Load(ctx.Field("tbl"), table)
Load(ctx.Field("pbr0"), br0)
Load(ctx.Field("pbr1"), br1)
Load(ctx.Field("pbr2"), br2)
Load(ctx.Field("pbr3"), br3)
Comment("Main loop")
Label(name + "_main_loop")
// Two values per stream; the ids n+0 .. n+3 name the labels and select the buffers.
d.decodeTwoValues(d.n+0, br0, table, buffer, off, exhausted)
d.decodeTwoValues(d.n+1, br1, table, buffer, off, exhausted)
d.decodeTwoValues(d.n+2, br2, table, buffer, off, exhausted)
d.decodeTwoValues(d.n+3, br3, table, buffer, off, exhausted)
ADDB(U8(2), off.As8()) // off += 2
// Leave the loop when a reader is exhausted, otherwise repeat until off wraps to 0.
TESTB(exhausted.As8(), exhausted.As8()) // any br[i].ofs < 4?
JNZ(LabelRef(name + "_done"))
CMPB(off.As8(), U8(0))
JNZ(LabelRef(name + "_main_loop"))
Label(name + "_done")
// Store the low byte of off into the function's return slot.
offsetComp, err := ReturnIndex(0).Resolve()
// NOTE(review): the handling of err is not visible here.
if err != nil {
MOVB(off.As8(), offsetComp.Addr)
// Byte offsets of the bit reader fields that the generated assembly reads and
// writes directly through a raw pointer to the reader.
//
// TODO [wmu]: I believe it's doable in avo, but can't figure out how to deal
// with arbitrary pointers to a given type
const (
	// bitReader_in is the offset of the input slice header, at the start of the struct.
	bitReader_in = 0

	// bitReader_off comes right after the three 8-byte words of the
	// slice header: {ptr, len, cap}.
	bitReader_off = bitReader_in + 3*8

	// bitReader_value follows the 8-byte off field.
	bitReader_value = bitReader_off + 8

	// bitReader_bitsRead follows the 8-byte value field.
	bitReader_bitsRead = bitReader_value + 8
)
// fillFast32 emits code that loads the reader's value and bitsRead into fresh
// registers and, when bitsRead is above 64-atLeast, refills 32 bits from the
// input and folds the "input running out" condition into exhausted.
//
// id makes the skip label unique, atLeast is the number of bits that must be
// available without a refill, and br holds the pointer to the reader.
// The two named results are the registers holding value and bitsRead.
//
// NOTE(review): closing braces and possibly other statements (including the
// return) are not visible in this listing; only the present lines are described.
func (d decompress4x) fillFast32(id, atLeast int, br, exhausted reg.GPVirtual) (brValue, brBitsRead reg.GPVirtual) {
Commentf("br%d.fillFast32()", id)
brValue = GP64()
brBitsRead = GP64()
MOVQ(Mem{Base: br, Disp: bitReader_value}, brValue)
// bitsRead is a single byte in memory; zero-extend it to 64 bits.
MOVBQZX(Mem{Base: br, Disp: bitReader_bitsRead}, brBitsRead)
brOffset := GP64()
MOVQ(Mem{Base: br, Disp: bitReader_off}, brOffset)
// We must have at least 2 * max tablelog left
// Skip the refill when bitsRead <= 64-atLeast.
CMPQ(brBitsRead, U8(64-atLeast))
JBE(LabelRef("skip_fill" + strconv.Itoa(id)))
SUBQ(U8(32), brBitsRead) // b.bitsRead -= 32
SUBQ(U8(4), brOffset) // b.off -= 4
// v := b.in[b.off-4 : b.off] (off as it was before the subtraction above)
// v = v[:4]
// low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24)
// tmp first holds the input data pointer, then the shifted 32-bit word.
tmp := GP64()
MOVQ(Mem{Base: br, Disp: bitReader_in}, tmp)
Comment("b.value |= uint64(low) << (b.bitsRead & 63)")
// addr is the input pointer plus the already decremented offset.
addr := Mem{Base: brOffset, Index: tmp.As64(), Scale: 1}
if d.bmi2 {
// NOTE(review): with a memory source SHLXQ reads 8 bytes, whereas the path
// below loads only 4 — confirm before enabling bmi2.
SHLXQ(brBitsRead, addr, tmp.As64()) // tmp = uint32(b.in[b.off:]) << (b.bitsRead & 63)
} else {
// Variable shifts without BMI2 need the count in CL.
CX := reg.CL
MOVL(addr, tmp.As32()) // tmp = uint32(b.in[b.off:])
MOVQ(brBitsRead, CX.As64())
SHLQ(CX, tmp.As64())
ORQ(tmp.As64(), brValue)
// NOTE(review): the format string below has no verb for the id argument — confirm the intended text.
Commentf("exhausted = exhausted || ( < 4)", id)
CMPQ(brOffset, U8(4))
// NOTE(review): no instruction that sets tmp from the comparison flags is visible
// between the CMPQ above and the ORB below — confirm against the original.
tmp = GP64()
ORB(tmp.As8(), exhausted.As8())
// Write the decremented offset back to the reader.
MOVQ(brOffset, Mem{Base: br, Disp: bitReader_off})
Label("skip_fill" + strconv.Itoa(id))
// TODO: WIP, does not work.
// Fill, so there is at least 56 bits available.
// Would make it possible to decode all sizes with 4bytes/loop.
//
// fillFast56 emits code that steps the reader's offset back by bitsRead/8
// bytes, reloads a 64-bit value from the input at that offset and keeps only
// the low three bits of bitsRead. The named results are the registers holding
// value and bitsRead.
//
// NOTE(review): closing braces and possibly other statements (including the
// return) are not visible in this listing; only the present lines are described.
func (d decompress4x) fillFast56(id int, br, exhausted reg.GPVirtual) (brValue, brBitsRead reg.GPVirtual) {
// NOTE(review): the emitted comment says fillFast32 although this is fillFast56 — likely copy-paste.
Commentf("br%d.fillFast32()", id)
brBitsRead = GP64()
brOffset := GP64()
brPointer := GP64()
MOVQ(Mem{Base: br, Disp: bitReader_off}, brOffset)
MOVBQZX(Mem{Base: br, Disp: bitReader_bitsRead}, brBitsRead)
MOVQ(Mem{Base: br, Disp: bitReader_in}, brPointer)
// off is the number of whole bytes consumed so far.
off := GP64()
MOVQ(brBitsRead, off)
SHRQ(U8(3), off) // off = brBitsRead / 8
SUBQ(off, brOffset) // brOffset = brOffset - off
brValue = GP64()
MOVQ(Mem{Base: brPointer, Index: brOffset, Scale: 1}, brValue) // brValue = 8 bytes at brPointer + brOffset
ANDQ(U8(7), brBitsRead) // brBitsRead = brBitsRead & 7
// We must have at least 2 * max tablelog left
// NOTE(review): the format string below has no verb for the id argument — confirm the intended text.
Commentf("exhausted = exhausted || ( < 4)", id)
CMPQ(brOffset, U8(4))
// NOTE(review): no instruction that sets tmp from the comparison flags is visible
// between the CMPQ above and the ORB below — confirm against the original.
tmp := GP64()
ORB(tmp.As8(), exhausted.As8())
// Write the updated offset back to the reader.
MOVQ(brOffset, Mem{Base: br, Disp: bitReader_off})
// decodeTwoValues emits code that refills the reader if needed, decodes two
// table entries from it, stores the two decoded bytes at
// buffer[(id%10)*buffoff + off] and writes value and bitsRead back to the reader.
//
// id names the refill label and selects the stream's buffer; br, table,
// buffer, off and exhausted are the registers prepared by the caller.
//
// NOTE(review): closing braces are not visible in this listing, so the extent
// of each if/else branch cannot be read from it; only the present lines are described.
func (d decompress4x) decodeTwoValues(id int, br, table, buffer, off, exhausted reg.GPVirtual) {
// Two lookups of at most nBits each follow, so ask for 2*nBits available bits.
brValue, brBitsRead := d.fillFast32(id, d.nBits*2, br, exhausted)
Commentf("val0 := br%d.peekTopBits(peekBits)", id)
CX := reg.CL
val := GP64()
// Three alternative encodings of the peek; the constant conditions select the first.
if true {
MOVQ(U32(64-d.nBits), CX.As64())
MOVQ(brValue, val.As64())
SHRQ(CX, val.As64()) // val = value >> (64 - nBits)
} else if false {
mask := GP64()
MOVQ(U32(64-d.nBits|(d.nBits<<8)), mask)
BEXTRQ(mask, brValue, val.As64())
} else {
MOVQ(brValue, val.As64())
SHRQ(U8(64-d.nBits), val.As64()) // val = value >> (64 - nBits)
Comment("v0 := table[val0&mask]")
// Table entries are 2 bytes wide, hence the 16-bit load with scale 2.
v := reg.RDX
MOVW(Mem{Base: table, Index: val.As64(), Scale: 2}, v.As16())
// NOTE(review): the emitted comment text below has an unbalanced parenthesis.
Commentf("br%d.advance(uint8(v0.entry)", id)
out := reg.RAX // Fixed since we need 8H
MOVB(v.As8H(), out.As8()) // AL = uint8(v0.entry >> 8)
// The low byte of the entry is the number of bits to consume.
MOVBQZX(v.As8(), CX.As64())
if d.bmi2 {
SHLXQ(v.As64(), brValue, brValue) // value <<= n
} else {
SHLQ(CX, brValue) // value <<= n
ADDQ(CX.As64(), brBitsRead) // bits_read += n
Commentf("val1 := br%d.peekTopBits(peekBits)", id)
if true {
// Fastest by far on AMD Zen2+
MOVQ(U32(64-d.nBits), CX.As64())
MOVQ(brValue, val.As64())
SHRQ(CX, val.As64()) // val = value >> (64 - nBits)
} else if false {
// Requires BMI2, not much faster.
mask := GP64()
MOVQ(U32(64-d.nBits|(d.nBits<<8)), mask)
BEXTRQ(mask, brValue, val.As64())
} else {
// Slow on Zen2+
MOVQ(brValue, val.As64())
SHRQ(U8(64-d.nBits), val.As64()) // val = value >> (64 - nBits)
Comment("v1 := table[val1&mask]")
MOVW(Mem{Base: table, Index: val.As64(), Scale: 2}, v.As16()) // v = v1
Commentf("br%d.advance(uint8(v1.entry))", id)
MOVB(v.As8H(), out.As8H()) // AH = uint8(v1.entry >> 8)
MOVBQZX(v.As8(), CX.As64())
if d.bmi2 {
SHLXQ(v.As64(), brValue, brValue) // value <<= n
} else {
SHLQ(CX, brValue) // value <<= n
ADDQ(CX.As64(), brBitsRead) // bits_read += n
Comment("these two writes get coalesced")
Comment("buf[stream][off] = uint8(v0.entry >> 8)")
Comment("buf[stream][off+1] = uint8(v1.entry >> 8)")
// id%10 recovers the stream index from the id (the caller passes n+stream, with n a multiple of 10).
MOVW(out.As16(), Mem{Base: buffer, Index: off, Scale: 1, Disp: (id % 10) * buffoff})
// NOTE(review): "bitrader" in the emitted comment below looks like a typo for "bitreader".
Comment("update the bitrader reader structure")
MOVQ(brValue, Mem{Base: br, Disp: bitReader_value})
MOVB(brBitsRead.As8(), Mem{Base: br, Disp: bitReader_bitsRead})
// generateProcedure4x8bit emits the assembler function called name, which
// decodes four values from each of the four streams per pass and returns the
// low byte of the output offset. Unlike generateProcedure it reads the peek
// shift from the context's peekBits field instead of using d.nBits.
//
// NOTE(review): closing braces, and the loop label and jumps that
// generateProcedure has, are not visible in this listing; only the present
// lines are described.
func (d decompress4x) generateProcedure4x8bit(name string) {
TEXT(name, 0, "func(ctx* decompress4xContext) uint8")
// NOTE(review): the Doc text says "tablelog > 8", but the declaration comment of
// decompress4x_8b_main_loop_amd64 describes the 4-values-per-loop variant as
// tablelog <= 8 — confirm and fix in the string.
Doc(name+" is an x86 assembler implementation of Decompress4X when tablelog > 8.decodes a sequence", "")
// off is the output offset; only its low byte is updated and tested.
off := GP64()
XORQ(off, off)
// exhausted becomes non-zero once any reader is running out of input.
exhausted := GP64()
XORQ(exhausted.As64(), exhausted.As64())
peekBits := GP64()
buffer := GP64()
table := GP64()
br0 := GP64()
br1 := GP64()
br2 := GP64()
br3 := GP64()
Comment("Preload values")
// Load the context fields into registers once.
ctx := Dereference(Param("ctx"))
Load(ctx.Field("peekBits"), peekBits)
Load(ctx.Field("buf"), buffer)
Load(ctx.Field("tbl"), table)
Load(ctx.Field("pbr0"), br0)
Load(ctx.Field("pbr1"), br1)
Load(ctx.Field("pbr2"), br2)
Load(ctx.Field("pbr3"), br3)
Comment("Main loop")
// Four values per stream; the ids 0 .. 3 are the stream indices.
d.decodeFourValues(0, br0, peekBits, table, buffer, off, exhausted)
d.decodeFourValues(1, br1, peekBits, table, buffer, off, exhausted)
d.decodeFourValues(2, br2, peekBits, table, buffer, off, exhausted)
d.decodeFourValues(3, br3, peekBits, table, buffer, off, exhausted)
ADDB(U8(4), off.As8()) // off += 4
TESTB(exhausted.As8(), exhausted.As8()) // any br[i].ofs < 4?
CMPB(off.As8(), U8(0))
// Store the low byte of off into the function's return slot.
offsetComp, err := ReturnIndex(0).Resolve()
// NOTE(review): the handling of err is not visible here.
if err != nil {
MOVB(off.As8(), offsetComp.Addr)
// decodeFourValues emits code that refills the reader if needed, decodes four
// table entries from it, stores four bytes at buffer[id*buffoff + off] and
// writes value and bitsRead back to the reader.
//
// id is the stream index; br, peekBits, table, buffer, off and exhausted are
// the registers prepared by the caller.
//
// NOTE(review): closing braces and possibly other statements are not visible
// in this listing; only the present lines are described.
func (d decompress4x) decodeFourValues(id int, br, peekBits, table, buffer, off, exhausted reg.GPVirtual) {
// Ask for 32 available bits; id+1000 presumably keeps the skip_fill label
// distinct from those of the two-value procedures.
brValue, brBitsRead := d.fillFast32(id+1000, 32, br, exhausted)
// decompress emits one decode step: peek, table lookup, byte into outByte, advance.
decompress := func(valID int, outByte reg.Register) {
CX := reg.CL
val := GP64()
Commentf("val%d := br%d.peekTopBits(peekBits)", valID, id)
MOVQ(brValue, val.As64())
// The shift count comes from the peekBits register and must be in CL.
MOVQ(peekBits, CX.As64())
SHRQ(CX, val.As64()) // val = value >> peekBits
Commentf("v%d := table[val0&mask]", valID)
// Table entries are 2 bytes wide; the entry is loaded into CX itself.
MOVW(Mem{Base: table, Index: val.As64(), Scale: 2}, CX.As16())
Commentf("br%d.advance(uint8(v%d.entry)", id, valID)
MOVB(CX.As8H(), outByte) // outByte = uint8(entry >> 8)
// Keep only the low byte of the entry: the number of bits to consume.
MOVBQZX(CX.As8(), CX.As64())
if d.bmi2 {
SHLXQ(CX.As64(), brValue, brValue) // value <<= n
} else {
SHLQ(CX, brValue) // value <<= n
ADDQ(CX.As64(), brBitsRead) // bits_read += n
out := reg.RAX // Fixed since we need 8H
// NOTE(review): values 0 and 3 both target AL and values 1 and 2 both target AH,
// and no instruction that moves the first pair out of the way is visible between
// the calls, although four bytes are stored below — confirm against the original.
decompress(0, out.As8L())
decompress(1, out.As8H())
decompress(2, out.As8H())
decompress(3, out.As8L())
Comment("these four writes get coalesced")
Comment("buf[stream][off] = uint8(v0.entry >> 8)")
Comment("buf[stream][off+1] = uint8(v1.entry >> 8)")
Comment("buf[stream][off+2] = uint8(v2.entry >> 8)")
Comment("buf[stream][off+3] = uint8(v3.entry >> 8)")
// One 32-bit store writes all four output bytes of this stream.
MOVL(out.As32(), Mem{Base: buffer, Index: off, Scale: 1, Disp: id * buffoff})
Comment("update the bitreader reader structure")
MOVQ(brValue, Mem{Base: br, Disp: bitReader_value})
MOVB(brBitsRead.As8(), Mem{Base: br, Disp: bitReader_bitsRead})
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment