Created May 3, 2022 11:25
Save klauspost/617e149f31f8967bc184f5a48c3834f4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//go:build amd64 && !appengine && !noasm && gc | |
// +build amd64,!appengine,!noasm,gc | |
// This file contains the specialisation of Decoder.Decompress4X | |
// that uses an asm implementation of its main loop. | |
package huff0 | |
import ( | |
"errors" | |
"fmt" | |
) | |
// decompress4x_main_loop_amd64_9 is an x86 assembler implementation | |
// of Decompress4X when tablelog > 8. | |
//go:noescape | |
func decompress4x_main_loop_amd64_9(ctx *decompress4xContext) uint8 | |
//go:noescape | |
func decompress4x_main_loop_amd64_10(ctx *decompress4xContext) uint8 | |
//go:noescape | |
func decompress4x_main_loop_amd64_11(ctx *decompress4xContext) uint8 | |
// decompress4x_8b_loop_x86 is an x86 assembler implementation | |
// of Decompress4X when tablelog <= 8 which decodes 4 entries | |
// per loop. | |
//go:noescape | |
func decompress4x_8b_main_loop_amd64(ctx *decompress4xContext) uint8 | |
// fallback8BitSize is the size where using Go version is faster.
// Decompress4X uses the pure Go decompress4X8bit path when cap(dst) is
// below this value and tablelog <= 8.
const fallback8BitSize = 800

// decompress4xContext is passed by pointer to the assembler main loops.
// The assembly generator looks these fields up by name, so the generated
// code depends on this struct's layout: regenerate it after any change here.
type decompress4xContext struct {
	pbr0     *bitReaderShifted // bit reader for stream 0
	pbr1     *bitReaderShifted // bit reader for stream 1
	pbr2     *bitReaderShifted // bit reader for stream 2
	pbr3     *bitReaderShifted // bit reader for stream 3
	peekBits uint8             // (64 - actualTableLog) & 63; only loaded by the 8-bit loop
	buf      *byte             // first byte of the [4][256]byte scratch buffer
	tbl      *dEntrySingle     // first entry of the single-symbol decoding table
}
// Decompress4X will decompress a 4X encoded stream. | |
// The length of the supplied input must match the end of a block exactly. | |
// The *capacity* of the dst slice must match the destination size of | |
// the uncompressed data exactly. | |
func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) { | |
if len(d.dt.single) == 0 { | |
return nil, errors.New("no table loaded") | |
} | |
if len(src) < 6+(4*1) { | |
return nil, errors.New("input too small") | |
} | |
use8BitTables := d.actualTableLog <= 8 | |
if cap(dst) < fallback8BitSize && use8BitTables { | |
return d.decompress4X8bit(dst, src) | |
} | |
var br [4]bitReaderShifted | |
// Decode "jump table" | |
start := 6 | |
for i := 0; i < 3; i++ { | |
length := int(src[i*2]) | (int(src[i*2+1]) << 8) | |
if start+length >= len(src) { | |
return nil, errors.New("truncated input (or invalid offset)") | |
} | |
err := br[i].init(src[start : start+length]) | |
if err != nil { | |
return nil, err | |
} | |
start += length | |
} | |
err := br[3].init(src[start:]) | |
if err != nil { | |
return nil, err | |
} | |
// destination, offset to match first output | |
dstSize := cap(dst) | |
dst = dst[:dstSize] | |
out := dst | |
dstEvery := (dstSize + 3) / 4 | |
const tlSize = 1 << tableLogMax | |
const tlMask = tlSize - 1 | |
single := d.dt.single[:tlSize] | |
// Use temp table to avoid bound checks/append penalty. | |
buf := d.buffer() | |
var off uint8 | |
var decoded int | |
const debug = false | |
ctx := decompress4xContext{ | |
pbr0: &br[0], | |
pbr1: &br[1], | |
pbr2: &br[2], | |
pbr3: &br[3], | |
peekBits: uint8((64 - d.actualTableLog) & 63), // see: bitReaderShifted.peekBitsFast() | |
buf: &buf[0][0], | |
tbl: &single[0], | |
} | |
// Decode 2 values from each decoder/loop. | |
const bufoff = 256 | |
for { | |
if br[0].off < 4 || br[1].off < 4 || br[2].off < 4 || br[3].off < 4 { | |
break | |
} | |
if use8BitTables { | |
off = decompress4x_8b_main_loop_amd64(&ctx) | |
} else { | |
switch d.actualTableLog { | |
case 9: | |
off = decompress4x_main_loop_amd64_9(&ctx) | |
case 10: | |
off = decompress4x_main_loop_amd64_10(&ctx) | |
case 11: | |
off = decompress4x_main_loop_amd64_11(&ctx) | |
default: | |
//panic(fmt.Sprintf("unexpected tablelog: %d", d.actualTableLog)) | |
} | |
} | |
if debug { | |
fmt.Print("DEBUG: ") | |
fmt.Printf("off=%d,", off) | |
for i := 0; i < 4; i++ { | |
fmt.Printf(" br[%d]={bitsRead=%d, value=%x, off=%d}", | |
i, br[i].bitsRead, br[i].value, br[i].off) | |
} | |
fmt.Println("") | |
} | |
if off != 0 { | |
break | |
} | |
if bufoff > dstEvery { | |
d.bufs.Put(buf) | |
return nil, errors.New("corruption detected: stream overrun 1") | |
} | |
copy(out, buf[0][:]) | |
copy(out[dstEvery:], buf[1][:]) | |
copy(out[dstEvery*2:], buf[2][:]) | |
copy(out[dstEvery*3:], buf[3][:]) | |
out = out[bufoff:] | |
decoded += bufoff * 4 | |
// There must at least be 3 buffers left. | |
if len(out) < dstEvery*3 { | |
d.bufs.Put(buf) | |
return nil, errors.New("corruption detected: stream overrun 2") | |
} | |
} | |
if off > 0 { | |
ioff := int(off) | |
if len(out) < dstEvery*3+ioff { | |
d.bufs.Put(buf) | |
return nil, errors.New("corruption detected: stream overrun 3") | |
} | |
copy(out, buf[0][:off]) | |
copy(out[dstEvery:], buf[1][:off]) | |
copy(out[dstEvery*2:], buf[2][:off]) | |
copy(out[dstEvery*3:], buf[3][:off]) | |
decoded += int(off) * 4 | |
out = out[off:] | |
} | |
// Decode remaining. | |
remainBytes := dstEvery - (decoded / 4) | |
for i := range br { | |
offset := dstEvery * i | |
endsAt := offset + remainBytes | |
if endsAt > len(out) { | |
endsAt = len(out) | |
} | |
br := &br[i] | |
bitsLeft := br.remaining() | |
for bitsLeft > 0 { | |
br.fill() | |
if offset >= endsAt { | |
d.bufs.Put(buf) | |
return nil, errors.New("corruption detected: stream overrun 4") | |
} | |
// Read value and increment offset. | |
val := br.peekBitsFast(d.actualTableLog) | |
v := single[val&tlMask].entry | |
nBits := uint8(v) | |
br.advance(nBits) | |
bitsLeft -= uint(nBits) | |
out[offset] = uint8(v >> 8) | |
offset++ | |
} | |
if offset != endsAt { | |
d.bufs.Put(buf) | |
return nil, fmt.Errorf("corruption detected: short output block %d, end %d != %d", i, offset, endsAt) | |
} | |
decoded += offset - dstEvery*i | |
err = br.close() | |
if err != nil { | |
return nil, err | |
} | |
} | |
d.bufs.Put(buf) | |
if dstSize != decoded { | |
return nil, errors.New("corruption detected: short output block") | |
} | |
return dst, nil | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
//go:generate go run gen.go -out ../decompress_amd64.s -pkg=huff0 | |
import ( | |
"flag" | |
"fmt" | |
"strconv" | |
_ "github.com/klauspost/compress" | |
. "github.com/mmcloughlin/avo/build" | |
"github.com/mmcloughlin/avo/buildtags" | |
. "github.com/mmcloughlin/avo/operand" | |
"github.com/mmcloughlin/avo/reg" | |
) | |
func main() { | |
flag.Parse() | |
Constraint(buildtags.Not("appengine").ToConstraint()) | |
Constraint(buildtags.Not("noasm").ToConstraint()) | |
Constraint(buildtags.Term("gc").ToConstraint()) | |
Constraint(buildtags.Not("noasm").ToConstraint()) | |
decompress := decompress4x{} | |
for i := 9; i <= 11; i++ { | |
decompress.nBits = i | |
decompress.n = i * 10 | |
decompress.generateProcedure(fmt.Sprintf("decompress4x_main_loop_amd64_%d", i)) | |
} | |
decompress8b := decompress4x{} | |
decompress8b.generateProcedure4x8bit("decompress4x_8b_main_loop_amd64") | |
Generate() | |
} | |
// buffoff is the byte distance between consecutive per-stream scratch
// buffers; see decompress.go, we're using [4][256]byte table.
const (
	buffoff = 256
)

// decompress4x holds the parameters used to generate one assembler
// decoding procedure.
type decompress4x struct {
	bmi2  bool // emit BMI2 SHLXQ shifts; not currently set by main
	nBits int  // tablelog a tablelog > 8 procedure is specialised for
	n     int  // base id for the four per-stream code sequences of a tablelog > 8 procedure
}
func (d decompress4x) generateProcedure(name string) { | |
Package("github.com/klauspost/compress/huff0") | |
TEXT(name, 0, "func(ctx* decompress4xContext) uint8") | |
Doc(name+" is an x86 assembler implementation of Decompress4X when tablelog > 8.decodes a sequence", "") | |
Pragma("noescape") | |
off := GP64() | |
XORQ(off, off) | |
exhausted := GP64() | |
XORQ(exhausted.As64(), exhausted.As64()) // exhausted = false | |
buffer := GP64() | |
table := GP64() | |
br0 := GP64() | |
br1 := GP64() | |
br2 := GP64() | |
br3 := GP64() | |
Comment("Preload values") | |
{ | |
ctx := Dereference(Param("ctx")) | |
Load(ctx.Field("buf"), buffer) | |
Load(ctx.Field("tbl"), table) | |
Load(ctx.Field("pbr0"), br0) | |
Load(ctx.Field("pbr1"), br1) | |
Load(ctx.Field("pbr2"), br2) | |
Load(ctx.Field("pbr3"), br3) | |
} | |
Comment("Main loop") | |
Label(name + "_main_loop") | |
d.decodeTwoValues(d.n+0, br0, table, buffer, off, exhausted) | |
d.decodeTwoValues(d.n+1, br1, table, buffer, off, exhausted) | |
d.decodeTwoValues(d.n+2, br2, table, buffer, off, exhausted) | |
d.decodeTwoValues(d.n+3, br3, table, buffer, off, exhausted) | |
ADDB(U8(2), off.As8()) // off += 2 | |
TESTB(exhausted.As8(), exhausted.As8()) // any br[i].ofs < 4? | |
JNZ(LabelRef(name + "_done")) | |
CMPB(off.As8(), U8(0)) | |
JNZ(LabelRef(name + "_main_loop")) | |
Label(name + "_done") | |
offsetComp, err := ReturnIndex(0).Resolve() | |
if err != nil { | |
panic(err) | |
} | |
MOVB(off.As8(), offsetComp.Addr) | |
RET() | |
} | |
// Byte offsets of the bitReaderShifted fields accessed by the generated
// code. They are hard-coded rather than resolved through avo, and assume
// bitReaderShifted is laid out as {in []byte; off (8 bytes); value uint64;
// bitsRead uint8} — keep them in sync with that struct.
// TODO [wmu]: I believe it's doable in avo, but can't figure out how to deal
// with arbitrary pointers to a given type
const (
	bitReader_in       = 0
	bitReader_off      = bitReader_in + 3*8 // in is a slice header: {ptr, len, cap}
	bitReader_value    = bitReader_off + 8
	bitReader_bitsRead = bitReader_value + 8
)
func (d decompress4x) fillFast32(id, atLeast int, br, exhausted reg.GPVirtual) (brValue, brBitsRead reg.GPVirtual) { | |
Commentf("br%d.fillFast32()", id) | |
brValue = GP64() | |
brBitsRead = GP64() | |
MOVQ(Mem{Base: br, Disp: bitReader_value}, brValue) | |
MOVBQZX(Mem{Base: br, Disp: bitReader_bitsRead}, brBitsRead) | |
brOffset := GP64() | |
MOVQ(Mem{Base: br, Disp: bitReader_off}, brOffset) | |
// We must have at least 2 * max tablelog left | |
CMPQ(brBitsRead, U8(64-atLeast)) | |
JBE(LabelRef("skip_fill" + strconv.Itoa(id))) | |
SUBQ(U8(32), brBitsRead) // b.bitsRead -= 32 | |
SUBQ(U8(4), brOffset) // b.off -= 4 | |
// v := b.in[b.off-4 : b.off] | |
// v = v[:4] | |
// low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24) | |
tmp := GP64() | |
MOVQ(Mem{Base: br, Disp: bitReader_in}, tmp) | |
Comment("b.value |= uint64(low) << (b.bitsRead & 63)") | |
addr := Mem{Base: brOffset, Index: tmp.As64(), Scale: 1} | |
if d.bmi2 { | |
SHLXQ(brBitsRead, addr, tmp.As64()) // tmp = uint32(b.in[b.off:b.off+4]) << (b.bitsRead & 63) | |
} else { | |
CX := reg.CL | |
MOVL(addr, tmp.As32()) // tmp = uint32(b.in[b.off:b.off+4]) | |
MOVQ(brBitsRead, CX.As64()) | |
SHLQ(CX, tmp.As64()) | |
} | |
ORQ(tmp.As64(), brValue) | |
{ | |
Commentf("exhausted = exhausted || (br%d.off < 4)", id) | |
CMPQ(brOffset, U8(4)) | |
tmp = GP64() | |
SETLT(tmp.As8()) | |
ORB(tmp.As8(), exhausted.As8()) | |
} | |
MOVQ(brOffset, Mem{Base: br, Disp: bitReader_off}) | |
Label("skip_fill" + strconv.Itoa(id)) | |
return | |
} | |
// TODO: WIP, does not work. | |
// Fill, so there is at least 56 bits available. | |
// Would make it possible to decode all sizes with 4bytes/loop. | |
func (d decompress4x) fillFast56(id int, br, exhausted reg.GPVirtual) (brValue, brBitsRead reg.GPVirtual) { | |
Commentf("br%d.fillFast32()", id) | |
brBitsRead = GP64() | |
brOffset := GP64() | |
brPointer := GP64() | |
MOVQ(Mem{Base: br, Disp: bitReader_off}, brOffset) | |
MOVBQZX(Mem{Base: br, Disp: bitReader_bitsRead}, brBitsRead) | |
MOVQ(Mem{Base: br, Disp: bitReader_in}, brPointer) | |
off := GP64() | |
MOVQ(brBitsRead, off) | |
SHRQ(U8(3), off) // off = brBitsRead / 8 | |
SUBQ(off, brOffset) // brOffset = brOffset - off | |
brValue = GP64() | |
MOVQ(Mem{Base: brPointer, Index: brOffset, Scale: 1}, brValue) // brValue = brPointer[brOffset] | |
ANDQ(U8(7), brBitsRead) // brBitsRead = brBitsRead & 7 | |
// We must have at least 2 * max tablelog left | |
{ | |
Commentf("exhausted = exhausted || (br%d.off < 4)", id) | |
CMPQ(brOffset, U8(4)) | |
tmp := GP64() | |
SETLT(tmp.As8()) | |
ORB(tmp.As8(), exhausted.As8()) | |
} | |
MOVQ(brOffset, Mem{Base: br, Disp: bitReader_off}) | |
return | |
} | |
func (d decompress4x) decodeTwoValues(id int, br, table, buffer, off, exhausted reg.GPVirtual) { | |
brValue, brBitsRead := d.fillFast32(id, d.nBits*2, br, exhausted) | |
Commentf("val0 := br%d.peekTopBits(peekBits)", id) | |
CX := reg.CL | |
val := GP64() | |
if true { | |
MOVQ(U32(64-d.nBits), CX.As64()) | |
MOVQ(brValue, val.As64()) | |
SHRQ(CX, val.As64()) // val = (value >> peek_bits) & mask | |
} else if false { | |
mask := GP64() | |
MOVQ(U32(64-d.nBits|(d.nBits<<8)), mask) | |
BEXTRQ(mask, brValue, val.As64()) | |
} else { | |
MOVQ(brValue, val.As64()) | |
SHRQ(U8(64-d.nBits), val.As64()) // val = (value >> peek_bits) & mask | |
} | |
Comment("v0 := table[val0&mask]") | |
v := reg.RDX | |
MOVW(Mem{Base: table, Index: val.As64(), Scale: 2}, v.As16()) | |
Commentf("br%d.advance(uint8(v0.entry)", id) | |
out := reg.RAX // Fixed since we need 8H | |
MOVB(v.As8H(), out.As8()) // BL = uint8(v0.entry >> 8) | |
MOVBQZX(v.As8(), CX.As64()) | |
if d.bmi2 { | |
SHLXQ(v.As64(), brValue, brValue) // value <<= n | |
} else { | |
SHLQ(CX, brValue) // value <<= n | |
} | |
ADDQ(CX.As64(), brBitsRead) // bits_read += n | |
Commentf("val1 := br%d.peekTopBits(peekBits)", id) | |
if true { | |
// Fastest by far on AMD Zen2+ | |
MOVQ(U32(64-d.nBits), CX.As64()) | |
MOVQ(brValue, val.As64()) | |
SHRQ(CX, val.As64()) // val = (value >> peek_bits) & mask | |
} else if false { | |
// Requires BMI2, not much faster. | |
mask := GP64() | |
MOVQ(U32(64-d.nBits|(d.nBits<<8)), mask) | |
BEXTRQ(mask, brValue, val.As64()) | |
} else { | |
// Slow on Zen2+ | |
MOVQ(brValue, val.As64()) | |
SHRQ(U8(64-d.nBits), val.As64()) // val = (value >> peek_bits) & mask | |
} | |
Comment("v1 := table[val1&mask]") | |
MOVW(Mem{Base: table, Index: val.As64(), Scale: 2}, v.As16()) // tmp - v1 | |
Commentf("br%d.advance(uint8(v1.entry))", id) | |
MOVB(v.As8H(), out.As8H()) // BH = uint8(v0.entry >> 8) | |
MOVBQZX(v.As8(), CX.As64()) | |
if d.bmi2 { | |
SHLXQ(v.As64(), brValue, brValue) // value <<= n | |
} else { | |
SHLQ(CX, brValue) // value <<= n | |
} | |
ADDQ(CX.As64(), brBitsRead) // bits_read += n | |
Comment("these two writes get coalesced") | |
Comment("buf[stream][off] = uint8(v0.entry >> 8)") | |
Comment("buf[stream][off+1] = uint8(v1.entry >> 8)") | |
MOVW(out.As16(), Mem{Base: buffer, Index: off, Scale: 1, Disp: (id % 10) * buffoff}) | |
Comment("update the bitrader reader structure") | |
MOVQ(brValue, Mem{Base: br, Disp: bitReader_value}) | |
MOVB(brBitsRead.As8(), Mem{Base: br, Disp: bitReader_bitsRead}) | |
} | |
func (d decompress4x) generateProcedure4x8bit(name string) { | |
Package("github.com/klauspost/compress/huff0") | |
TEXT(name, 0, "func(ctx* decompress4xContext) uint8") | |
Doc(name+" is an x86 assembler implementation of Decompress4X when tablelog > 8.decodes a sequence", "") | |
Pragma("noescape") | |
off := GP64() | |
XORQ(off, off) | |
exhausted := GP64() | |
XORQ(exhausted.As64(), exhausted.As64()) | |
peekBits := GP64() | |
buffer := GP64() | |
table := GP64() | |
br0 := GP64() | |
br1 := GP64() | |
br2 := GP64() | |
br3 := GP64() | |
Comment("Preload values") | |
{ | |
ctx := Dereference(Param("ctx")) | |
Load(ctx.Field("peekBits"), peekBits) | |
Load(ctx.Field("buf"), buffer) | |
Load(ctx.Field("tbl"), table) | |
Load(ctx.Field("pbr0"), br0) | |
Load(ctx.Field("pbr1"), br1) | |
Load(ctx.Field("pbr2"), br2) | |
Load(ctx.Field("pbr3"), br3) | |
} | |
Comment("Main loop") | |
Label("main_loop") | |
d.decodeFourValues(0, br0, peekBits, table, buffer, off, exhausted) | |
d.decodeFourValues(1, br1, peekBits, table, buffer, off, exhausted) | |
d.decodeFourValues(2, br2, peekBits, table, buffer, off, exhausted) | |
d.decodeFourValues(3, br3, peekBits, table, buffer, off, exhausted) | |
ADDB(U8(4), off.As8()) // off += 4 | |
TESTB(exhausted.As8(), exhausted.As8()) // any br[i].ofs < 4? | |
JNZ(LabelRef("done")) | |
CMPB(off.As8(), U8(0)) | |
JNZ(LabelRef("main_loop")) | |
Label("done") | |
offsetComp, err := ReturnIndex(0).Resolve() | |
if err != nil { | |
panic(err) | |
} | |
MOVB(off.As8(), offsetComp.Addr) | |
RET() | |
} | |
func (d decompress4x) decodeFourValues(id int, br, peekBits, table, buffer, off, exhausted reg.GPVirtual) { | |
brValue, brBitsRead := d.fillFast32(id+1000, 32, br, exhausted) | |
decompress := func(valID int, outByte reg.Register) { | |
CX := reg.CL | |
val := GP64() | |
Commentf("val%d := br%d.peekTopBits(peekBits)", valID, id) | |
MOVQ(brValue, val.As64()) | |
MOVQ(peekBits, CX.As64()) | |
SHRQ(CX, val.As64()) // val = (value >> peek_bits) & mask | |
Commentf("v%d := table[val0&mask]", valID) | |
MOVW(Mem{Base: table, Index: val.As64(), Scale: 2}, CX.As16()) | |
Commentf("br%d.advance(uint8(v%d.entry)", id, valID) | |
MOVB(CX.As8H(), outByte) // BL = uint8(v0.entry >> 8) | |
MOVBQZX(CX.As8(), CX.As64()) | |
if d.bmi2 { | |
SHLXQ(CX.As64(), brValue, brValue) // value <<= n | |
} else { | |
SHLQ(CX, brValue) // value <<= n | |
} | |
ADDQ(CX.As64(), brBitsRead) // bits_read += n | |
} | |
out := reg.RAX // Fixed since we need 8H | |
decompress(0, out.As8L()) | |
decompress(1, out.As8H()) | |
BSWAPL(out.As32()) | |
decompress(2, out.As8H()) | |
decompress(3, out.As8L()) | |
BSWAPL(out.As32()) | |
Comment("these four writes get coalesced") | |
Comment("buf[stream][off] = uint8(v0.entry >> 8)") | |
Comment("buf[stream][off+1] = uint8(v1.entry >> 8)") | |
Comment("buf[stream][off+2] = uint8(v2.entry >> 8)") | |
Comment("buf[stream][off+3] = uint8(v3.entry >> 8)") | |
MOVL(out.As32(), Mem{Base: buffer, Index: off, Scale: 1, Disp: id * buffoff}) | |
Comment("update the bitreader reader structure") | |
MOVQ(brValue, Mem{Base: br, Disp: bitReader_value}) | |
MOVB(brBitsRead.As8(), Mem{Base: br, Disp: bitReader_bitsRead}) | |
} |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.