Last active
May 6, 2022 12:49
-
-
Save klauspost/306ed219fb5e275d7c9abb87c9a7b142 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
//go:build amd64 && !appengine && !noasm && gc | |
// +build amd64,!appengine,!noasm,gc | |
// This file contains the specialisation of Decoder.Decompress4X | |
// that uses an asm implementation of its main loop. | |
package huff0 | |
import ( | |
"errors" | |
"fmt" | |
) | |
// decompress4x_main_loop_x86 is an x86 assembler implementation | |
// of Decompress4X when tablelog > 8. | |
//go:noescape | |
func decompress4x_main_loop_amd64(ctx *decompress4xContext) | |
// decompress4x_8b_loop_x86 is an x86 assembler implementation | |
// of Decompress4X when tablelog <= 8 which decodes 4 entries | |
// per loop. | |
//go:noescape | |
func decompress4x_8b_main_loop_amd64(ctx *decompress4xContext) | |
// fallback8BitSize is the size where using Go version is faster. | |
const fallback8BitSize = 800 | |
type decompress4xContext struct { | |
pbr0 *bitReaderShifted | |
pbr1 *bitReaderShifted | |
pbr2 *bitReaderShifted | |
pbr3 *bitReaderShifted | |
peekBits uint8 | |
out *byte | |
dstEvery int | |
tbl *dEntrySingle | |
decoded int | |
limit *byte | |
} | |
// Decompress4X will decompress a 4X encoded stream. | |
// The length of the supplied input must match the end of a block exactly. | |
// The *capacity* of the dst slice must match the destination size of | |
// the uncompressed data exactly. | |
func (d *Decoder) Decompress4X(dst, src []byte) ([]byte, error) { | |
if len(d.dt.single) == 0 { | |
return nil, errors.New("no table loaded") | |
} | |
if len(src) < 6+(4*1) { | |
return nil, errors.New("input too small") | |
} | |
use8BitTables := d.actualTableLog <= 8 | |
if cap(dst) < fallback8BitSize && use8BitTables { | |
return d.decompress4X8bit(dst, src) | |
} | |
var br [4]bitReaderShifted | |
// Decode "jump table" | |
start := 6 | |
for i := 0; i < 3; i++ { | |
length := int(src[i*2]) | (int(src[i*2+1]) << 8) | |
if start+length >= len(src) { | |
return nil, errors.New("truncated input (or invalid offset)") | |
} | |
err := br[i].init(src[start : start+length]) | |
if err != nil { | |
return nil, err | |
} | |
start += length | |
} | |
err := br[3].init(src[start:]) | |
if err != nil { | |
return nil, err | |
} | |
// destination, offset to match first output | |
dstSize := cap(dst) | |
dst = dst[:dstSize] | |
out := dst | |
dstEvery := (dstSize + 3) / 4 | |
const tlSize = 1 << tableLogMax | |
const tlMask = tlSize - 1 | |
single := d.dt.single[:tlSize] | |
// Use temp table to avoid bound checks/append penalty. | |
var decoded int | |
const debug = false | |
ctx := decompress4xContext{ | |
pbr0: &br[0], | |
pbr1: &br[1], | |
pbr2: &br[2], | |
pbr3: &br[3], | |
peekBits: uint8((64 - d.actualTableLog) & 63), // see: bitReaderShifted.peekBitsFast() | |
out: &out[0], | |
dstEvery: dstEvery, | |
tbl: &single[0], | |
limit: &out[dstEvery-4], // Always stop decoding when first buffer gets here to avoid writing OOB on last. | |
} | |
// Decode 2 values from each decoder/loop. | |
if !(br[0].off < 4 || br[1].off < 4 || br[2].off < 4 || br[3].off < 4) { | |
if use8BitTables { | |
decompress4x_8b_main_loop_amd64(&ctx) | |
} else { | |
decompress4x_main_loop_amd64(&ctx) | |
} | |
decoded = ctx.decoded | |
out = out[decoded/4:] | |
} | |
// Decode remaining. | |
remainBytes := dstEvery - (decoded / 4) | |
for i := range br { | |
offset := dstEvery * i | |
endsAt := offset + remainBytes | |
if endsAt > len(out) { | |
endsAt = len(out) | |
} | |
br := &br[i] | |
bitsLeft := br.remaining() | |
for bitsLeft > 0 { | |
br.fill() | |
if offset >= endsAt { | |
return nil, errors.New("corruption detected: stream overrun 4") | |
} | |
// Read value and increment offset. | |
val := br.peekBitsFast(d.actualTableLog) | |
v := single[val&tlMask].entry | |
nBits := uint8(v) | |
br.advance(nBits) | |
bitsLeft -= uint(nBits) | |
out[offset] = uint8(v >> 8) | |
offset++ | |
} | |
if offset != endsAt { | |
return nil, fmt.Errorf("corruption detected: short output block %d, end %d != %d", i, offset, endsAt) | |
} | |
decoded += offset - dstEvery*i | |
err = br.close() | |
if err != nil { | |
return nil, err | |
} | |
} | |
if dstSize != decoded { | |
return nil, errors.New("corruption detected: short output block") | |
} | |
return dst, nil | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
//go:generate go run gen.go -out ../decompress_amd64.s -pkg=huff0 | |
import ( | |
"flag" | |
"fmt" | |
"strconv" | |
_ "github.com/klauspost/compress" | |
. "github.com/mmcloughlin/avo/build" | |
"github.com/mmcloughlin/avo/buildtags" | |
. "github.com/mmcloughlin/avo/operand" | |
"github.com/mmcloughlin/avo/reg" | |
) | |
func main() { | |
flag.Parse() | |
ConstraintExpr("amd64,!appengine,!noasm,gc") | |
decompress := decompress4x{} | |
decompress.generateProcedure("decompress4x_main_loop_amd64") | |
decompress.generateProcedure4x8bit("decompress4x_8b_main_loop_amd64") | |
Generate() | |
} | |
type decompress4x struct { | |
bmi2 bool | |
} | |
func (d decompress4x) generateProcedure(name string) { | |
Package("github.com/klauspost/compress/huff0") | |
TEXT(name, 0, "func(ctx* decompress4xContext)") | |
Doc(name+" is an x86 assembler implementation of Decompress4X when tablelog > 8.decodes a sequence", "") | |
Pragma("noescape") | |
exhausted := GP64() | |
XORQ(exhausted.As64(), exhausted.As64()) // exhausted = false | |
limitPtr := AllocLocal(8) | |
bufferOrigin := GP64() | |
peekBits := GP64() | |
buffer := GP64() | |
dstEvery := GP64() | |
table := GP64() | |
br0 := GP64() | |
br1 := GP64() | |
br2 := GP64() | |
br3 := GP64() | |
Comment("Preload values") | |
{ | |
ctx := Dereference(Param("ctx")) | |
Load(ctx.Field("peekBits"), peekBits) | |
Load(ctx.Field("out"), buffer) | |
MOVQ(buffer, bufferOrigin) | |
limit := Load(ctx.Field("limit"), GP64()) | |
MOVQ(limit, limitPtr) | |
Load(ctx.Field("dstEvery"), dstEvery) | |
Load(ctx.Field("tbl"), table) | |
Load(ctx.Field("pbr0"), br0) | |
Load(ctx.Field("pbr1"), br1) | |
Load(ctx.Field("pbr2"), br2) | |
Load(ctx.Field("pbr3"), br3) | |
} | |
Comment("Main loop") | |
Label("main_loop") | |
MOVQ(bufferOrigin, buffer) | |
// Check if we have space | |
CMPQ(buffer, limitPtr) | |
SETGE(exhausted.As8()) | |
d.decodeTwoValues(0, br0, peekBits, table, buffer, exhausted) | |
ADDQ(dstEvery, buffer) | |
d.decodeTwoValues(1, br1, peekBits, table, buffer, exhausted) | |
ADDQ(dstEvery, buffer) | |
d.decodeTwoValues(2, br2, peekBits, table, buffer, exhausted) | |
ADDQ(dstEvery, buffer) | |
d.decodeTwoValues(3, br3, peekBits, table, buffer, exhausted) | |
ADDQ(U8(2), bufferOrigin) // off += 2 | |
TESTB(exhausted.As8(), exhausted.As8()) // any br[i].ofs < 4? | |
JZ(LabelRef("main_loop")) | |
{ | |
ctx := Dereference(Param("ctx")) | |
tmp := Load(ctx.Field("out"), GP64()) | |
decoded := GP64() | |
MOVQ(bufferOrigin, decoded) | |
SUBQ(tmp, decoded) | |
SHLQ(U8(2), decoded) // decoded *= 4 | |
Store(decoded, ctx.Field("decoded")) | |
} | |
RET() | |
} | |
// TODO [wmu]: I believe it's doable in avo, but can't figure out how to deal | |
// with arbitrary pointers to a given type | |
const bitReader_in = 0 | |
const bitReader_off = bitReader_in + 3*8 // {ptr, len, cap} | |
const bitReader_value = bitReader_off + 8 | |
const bitReader_bitsRead = bitReader_value + 8 | |
func (d decompress4x) decodeTwoValues(id int, br, peekBits, table, buffer, exhausted reg.GPVirtual) { | |
brValue, brBitsRead := d.fillFast32(id, 32, br, exhausted) | |
val := GP64() | |
Commentf("val0 := br%d.peekTopBits(peekBits)", id) | |
CX := reg.CL | |
if d.bmi2 { | |
SHRXQ(peekBits, brValue, val.As64()) // val = (value >> peek_bits) & mask | |
} else { | |
MOVQ(brValue, val.As64()) | |
MOVQ(peekBits, CX.As64()) | |
SHRQ(CX, val.As64()) // val = (value >> peek_bits) & mask | |
} | |
Comment("v0 := table[val0&mask]") | |
v := reg.RDX | |
MOVW(Mem{Base: table, Index: val.As64(), Scale: 2}, v.As16()) | |
Commentf("br%d.advance(uint8(v0.entry)", id) | |
out := reg.RAX // Fixed since we need 8H | |
MOVB(v.As8H(), out.As8()) // BL = uint8(v0.entry >> 8) | |
MOVBQZX(v.As8(), CX.As64()) | |
if d.bmi2 { | |
SHLXQ(v.As64(), brValue, brValue) // value <<= n | |
} else { | |
SHLQ(CX, brValue) // value <<= n | |
} | |
ADDQ(CX.As64(), brBitsRead) // bits_read += n | |
Commentf("val1 := br%d.peekTopBits(peekBits)", id) | |
if d.bmi2 { | |
SHRXQ(peekBits, brValue, val.As64()) // val = (value >> peek_bits) & mask | |
} else { | |
MOVQ(peekBits, CX.As64()) | |
MOVQ(brValue, val.As64()) | |
SHRQ(CX, val.As64()) // val = (value >> peek_bits) & mask | |
} | |
Comment("v1 := table[val1&mask]") | |
MOVW(Mem{Base: table, Index: val.As64(), Scale: 2}, v.As16()) // tmp - v1 | |
Commentf("br%d.advance(uint8(v1.entry))", id) | |
MOVB(v.As8H(), out.As8H()) // BH = uint8(v0.entry >> 8) | |
MOVBQZX(v.As8(), CX.As64()) | |
if d.bmi2 { | |
SHLXQ(v.As64(), brValue, brValue) // value <<= n | |
} else { | |
SHLQ(CX, brValue) // value <<= n | |
} | |
ADDQ(CX.As64(), brBitsRead) // bits_read += n | |
Comment("these two writes get coalesced") | |
Comment("out[stream][off] = uint8(v0.entry >> 8)") | |
Comment("out[stream][off+1] = uint8(v1.entry >> 8)") | |
MOVW(out.As16(), Mem{Base: buffer}) | |
Comment("update the bitrader reader structure") | |
MOVQ(brValue, Mem{Base: br, Disp: bitReader_value}) | |
MOVB(brBitsRead.As8(), Mem{Base: br, Disp: bitReader_bitsRead}) | |
} | |
func (d decompress4x) generateProcedure4x8bit(name string) { | |
Package("github.com/klauspost/compress/huff0") | |
TEXT(name, 0, "func(ctx* decompress4xContext)") | |
Doc(name+" is an x86 assembler implementation of Decompress4X when tablelog > 8.decodes a sequence", "") | |
Pragma("noescape") | |
exhausted := GP64() // Fixed since we need 8H | |
XORQ(exhausted.As64(), exhausted.As64()) // exhausted = false | |
bufferOrigin := AllocLocal(8) | |
limitPtr := AllocLocal(8) | |
peekBits := GP64() | |
buffer := GP64() | |
dstEvery := GP64() | |
table := GP64() | |
br0 := GP64() | |
br1 := GP64() | |
br2 := GP64() | |
br3 := GP64() | |
Comment("Preload values") | |
{ | |
ctx := Dereference(Param("ctx")) | |
Load(ctx.Field("peekBits"), peekBits) | |
Load(ctx.Field("out"), buffer) | |
MOVQ(buffer, bufferOrigin) | |
limit := Load(ctx.Field("limit"), GP64()) | |
MOVQ(limit, limitPtr) | |
Load(ctx.Field("dstEvery"), dstEvery) | |
Load(ctx.Field("tbl"), table) | |
Load(ctx.Field("pbr0"), br0) | |
Load(ctx.Field("pbr1"), br1) | |
Load(ctx.Field("pbr2"), br2) | |
Load(ctx.Field("pbr3"), br3) | |
} | |
Comment("Main loop") | |
Label("main_loop") | |
MOVQ(bufferOrigin, buffer) | |
// Check if we have space | |
CMPQ(buffer, limitPtr) | |
SETGE(exhausted.As8()) | |
d.decodeFourValues(0, br0, peekBits, table, buffer, exhausted) | |
ADDQ(dstEvery, buffer) | |
d.decodeFourValues(1, br1, peekBits, table, buffer, exhausted) | |
ADDQ(dstEvery, buffer) | |
d.decodeFourValues(2, br2, peekBits, table, buffer, exhausted) | |
ADDQ(dstEvery, buffer) | |
d.decodeFourValues(3, br3, peekBits, table, buffer, exhausted) | |
ADDQ(U8(4), bufferOrigin) // off += 4 | |
TESTB(exhausted.As8(), exhausted.As8()) // any br[i].ofs < 4? | |
JZ(LabelRef("main_loop")) | |
{ | |
ctx := Dereference(Param("ctx")) | |
tmp := Load(ctx.Field("out"), GP64()) | |
decoded := GP64() | |
MOVQ(bufferOrigin, decoded) | |
SUBQ(tmp, decoded) | |
SHLQ(U8(2), decoded) // decoded *= 4 | |
Store(decoded, ctx.Field("decoded")) | |
} | |
RET() | |
} | |
func (d decompress4x) decodeFourValues(id int, br, peekBits, table, buffer, exhausted reg.GPVirtual) { | |
brValue, brBitsRead := d.fillFast32(id+1000, 32, br, exhausted) | |
decompress := func(valID int, outByte reg.Register) { | |
CX := reg.CL | |
val := GP64() | |
Commentf("val%d := br%d.peekTopBits(peekBits)", valID, id) | |
if d.bmi2 { | |
SHRXQ(peekBits, brValue, val.As64()) // val = (value >> peek_bits) & mask | |
} else { | |
MOVQ(brValue, val.As64()) | |
MOVQ(peekBits, CX.As64()) | |
SHRQ(CX, val.As64()) // val = (value >> peek_bits) & mask | |
} | |
Commentf("v%d := table[val0&mask]", valID) | |
MOVW(Mem{Base: table, Index: val.As64(), Scale: 2}, CX.As16()) | |
Commentf("br%d.advance(uint8(v%d.entry)", id, valID) | |
MOVB(CX.As8H(), outByte) // BL = uint8(v0.entry >> 8) | |
MOVBQZX(CX.As8(), CX.As64()) | |
if d.bmi2 { | |
SHLXQ(CX.As64(), brValue, brValue) // value <<= n | |
} else { | |
SHLQ(CX, brValue) // value <<= n | |
} | |
ADDQ(CX.As64(), brBitsRead) // bits_read += n | |
} | |
out := reg.RAX // Fixed since we need 8H | |
decompress(0, out.As8L()) | |
decompress(1, out.As8H()) | |
BSWAPL(out.As32()) | |
decompress(2, out.As8H()) | |
decompress(3, out.As8L()) | |
BSWAPL(out.As32()) | |
Comment("these four writes get coalesced") | |
Comment("buf[stream][off] = uint8(v0.entry >> 8)") | |
Comment("buf[stream][off+1] = uint8(v1.entry >> 8)") | |
Comment("buf[stream][off+2] = uint8(v2.entry >> 8)") | |
Comment("buf[stream][off+3] = uint8(v3.entry >> 8)") | |
MOVL(out.As32(), Mem{Base: buffer}) | |
Comment("update the bitreader reader structure") | |
MOVQ(brValue, Mem{Base: br, Disp: bitReader_value}) | |
MOVB(brBitsRead.As8(), Mem{Base: br, Disp: bitReader_bitsRead}) | |
} | |
func (d decompress4x) fillFast32(id, atLeast int, br, exhausted reg.GPVirtual) (brValue, brBitsRead reg.GPVirtual) { | |
if atLeast > 32 { | |
panic(fmt.Sprintf("at least (%d) cannot be >32", atLeast)) | |
} | |
Commentf("br%d.fillFast32()", id) | |
brValue = GP64() | |
brBitsRead = GP64() | |
MOVQ(Mem{Base: br, Disp: bitReader_value}, brValue) | |
MOVBQZX(Mem{Base: br, Disp: bitReader_bitsRead}, brBitsRead) | |
// We must have at least 2 * max tablelog left | |
CMPQ(brBitsRead, U8(64-atLeast)) | |
JBE(LabelRef("skip_fill" + strconv.Itoa(id))) | |
brOffset := GP64() | |
MOVQ(Mem{Base: br, Disp: bitReader_off}, brOffset) | |
SUBQ(U8(32), brBitsRead) // b.bitsRead -= 32 | |
SUBQ(U8(4), brOffset) // b.off -= 4 | |
// v := b.in[b.off-4 : b.off] | |
// v = v[:4] | |
// low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24) | |
tmp := GP64() | |
MOVQ(Mem{Base: br, Disp: bitReader_in}, tmp) | |
Comment("b.value |= uint64(low) << (b.bitsRead & 63)") | |
addr := Mem{Base: brOffset, Index: tmp.As64(), Scale: 1} | |
if d.bmi2 { | |
SHLXQ(brBitsRead, addr, tmp.As64()) // tmp = uint32(b.in[b.off:b.off+4]) << (b.bitsRead & 63) | |
} else { | |
CX := reg.CL | |
MOVL(addr, tmp.As32()) // tmp = uint32(b.in[b.off:b.off+4]) | |
MOVQ(brBitsRead, CX.As64()) | |
SHLQ(CX, tmp.As64()) | |
} | |
MOVQ(brOffset, Mem{Base: br, Disp: bitReader_off}) | |
ORQ(tmp.As64(), brValue) | |
{ | |
Commentf("exhausted = exhausted || (br%d.off < 4)", id) | |
CMPQ(brOffset, U8(4)) | |
tmp = GP64() | |
SETLT(tmp.As8()) | |
ORB(tmp.As8(), exhausted.As8()) | |
} | |
Label("skip_fill" + strconv.Itoa(id)) | |
return | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment