Skip to content

Instantly share code, notes, and snippets.

@HalidCisse
Forked from dimzon/zstd.cs
Created April 1, 2018 00:44
Show Gist options
  • Save HalidCisse/1c1918fbe0e1b7f115ced05b79cefefc to your computer and use it in GitHub Desktop.
Save HalidCisse/1c1918fbe0e1b7f115ced05b79cefefc to your computer and use it in GitHub Desktop.
Zstd streaming API wrapper
using System;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Runtime.InteropServices;
namespace ZstdCompression
{
    /// <summary>
    /// Managed wrapper around the streaming API of the native <c>libzstd</c> library,
    /// exposing compression and decompression as <see cref="Stream"/> implementations.
    /// </summary>
    public static class Zstd
    {
        /// <summary>
        /// Checks a <c>size_t</c> result returned by a zstd function and throws if it encodes an error.
        /// </summary>
        /// <param name="x">Raw return value of a zstd API call.</param>
        /// <returns>
        /// <paramref name="x"/> unchanged, so callers can use non-error results
        /// (for example the "bytes still to flush" value returned by <c>ZSTD_endStream</c>).
        /// </returns>
        /// <exception cref="IOException"><c>ZSTD_isError</c> reports <paramref name="x"/> as an error.</exception>
        private static UIntPtr Call(UIntPtr x)
        {
            if (ZSTD_isError(x) == 0) return x;
            throw new IOException(GetErrorName(x));
        }

        /// <summary>
        /// Returns the human-readable name of a zstd error code.
        /// </summary>
        /// <param name="code">A value for which <c>ZSTD_isError</c> returned non-zero.</param>
        private static string GetErrorName(UIntPtr code)
        {
            // The native function returns a pointer to a string owned by libzstd. It is copied here
            // instead of being marshaled as 'string', which would make the runtime free that memory.
            var name = Marshal.PtrToStringAnsi(ZSTD_getErrorName(code));
            return name ?? "Unknown zstd error";
        }

        /// <summary>
        /// Validates the (buffer, offset, count) arguments of <c>Read</c>/<c>Write</c>.
        /// The native calls receive a raw pointer into the array, and
        /// <see cref="Marshal.UnsafeAddrOfPinnedArrayElement"/> does no bounds checking,
        /// so invalid ranges must be rejected up front.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="buffer"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="offset"/> or <paramref name="count"/> is negative.</exception>
        /// <exception cref="ArgumentException">The range extends past the end of <paramref name="buffer"/>.</exception>
        private static void CheckBufferArguments(byte[] buffer, int offset, int count)
        {
            if (buffer == null) throw new ArgumentNullException("buffer");
            if (offset < 0) throw new ArgumentOutOfRangeException("offset");
            if (count < 0) throw new ArgumentOutOfRangeException("count");
            if (buffer.Length - offset < count)
            {
                throw new ArgumentException("offset and count describe a range outside the buffer.");
            }
        }

        /// <summary>
        /// Write-only stream that compresses the data written to it with zstd and forwards the
        /// compressed bytes to an underlying stream. The zstd frame is finished when the stream is closed.
        /// </summary>
        public class ZStdOutputStream : Stream
        {
            private readonly Stream _outputStream;
            private readonly bool _leaveOpen;
            // Maximum number of input bytes handed to zstd per call (ZSTD_CStreamInSize()).
            private readonly int _inputBufferSize;
            // Native ZSTD_CStream handle; freed in Close.
            private readonly IntPtr _zst;
            // Managed staging area that zstd writes compressed bytes into (ZSTD_CStreamOutSize() bytes).
            private readonly byte[] _outputBufferArray;
            private readonly Buffer _inputBuffer = new Buffer();
            private readonly Buffer _outputBuffer = new Buffer();
            private bool _closed;

            /// <summary>
            /// Creates a compressing stream on top of <paramref name="outputStream"/>.
            /// </summary>
            /// <param name="outputStream">Destination for the compressed bytes.</param>
            /// <param name="level">Compression level passed to <c>ZSTD_initCStream</c>.</param>
            /// <param name="leaveOpen">When <c>false</c>, closing this stream also closes <paramref name="outputStream"/>.</param>
            /// <exception cref="ArgumentNullException"><paramref name="outputStream"/> is null.</exception>
            /// <exception cref="IOException">The native compression stream could not be created or initialized.</exception>
            public ZStdOutputStream(Stream outputStream, int level = 6, bool leaveOpen = false)
            {
                if (outputStream == null) throw new ArgumentNullException("outputStream");
                _outputStream = outputStream;
                _leaveOpen = leaveOpen;
                _zst = ZSTD_createCStream();
                if (_zst == IntPtr.Zero) throw new IOException("ZSTD_createCStream failed.");
                try
                {
                    Call(ZSTD_initCStream(_zst, level));
                }
                catch
                {
                    // The constructor is failing, so Close will never run; release the native stream here.
                    ZSTD_freeCStream(_zst);
                    throw;
                }
                _inputBufferSize = (int)ZSTD_CStreamInSize().ToUInt32();
                _outputBuffer.Size = ZSTD_CStreamOutSize();
                _outputBufferArray = new byte[(int)_outputBuffer.Size.ToUInt32()];
            }

            /// <summary>
            /// Finishes the zstd frame and writes it to the underlying stream. Then releases the native
            /// compression stream and, unless <c>leaveOpen</c> was set, closes the underlying stream.
            /// Later calls do nothing.
            /// </summary>
            public override void Close()
            {
                if (_closed) return;
                // Mark closed first so a failure below can never cause a second free of _zst.
                _closed = true;
                try
                {
                    // ZSTD_endStream flushes any buffered data and also writes the frame epilogue.
                    Flush(true);
                    _outputStream.Flush();
                }
                finally
                {
                    // The result is deliberately not checked: Close should not throw while cleaning up.
                    ZSTD_freeCStream(_zst);
                    try
                    {
                        if (!_leaveOpen) _outputStream.Close();
                    }
                    finally
                    {
                        base.Close();
                    }
                }
            }

            /// <summary>
            /// Writes all data that zstd has buffered so far to the underlying stream and flushes it,
            /// without ending the frame.
            /// </summary>
            /// <exception cref="ObjectDisposedException">The stream has been closed.</exception>
            public override void Flush()
            {
                ThrowIfClosed();
                Flush(false);
                _outputStream.Flush();
            }

            /// <summary>Not supported; this stream is not seekable.</summary>
            public override long Seek(long offset, SeekOrigin origin)
            {
                throw new NotSupportedException();
            }

            /// <summary>Not supported; this stream is not seekable.</summary>
            public override void SetLength(long value)
            {
                throw new NotSupportedException();
            }

            /// <summary>Not supported; this stream is write-only.</summary>
            public override int Read(byte[] buffer, int offset, int count)
            {
                throw new NotSupportedException();
            }

            /// <summary>
            /// Drains zstd's internal buffers into the underlying stream.
            /// </summary>
            /// <param name="end">
            /// <c>true</c> to finish the frame (<c>ZSTD_endStream</c>);
            /// <c>false</c> to flush only (<c>ZSTD_flushStream</c>).
            /// </param>
            private void Flush(bool end)
            {
                var outputHandle = GCHandle.Alloc(_outputBufferArray, GCHandleType.Pinned);
                try
                {
                    _outputBuffer.Data = outputHandle.AddrOfPinnedObject();
                    UIntPtr remaining;
                    do
                    {
                        _outputBuffer.Position = UIntPtr.Zero;
                        remaining = Call(end
                            ? ZSTD_endStream(_zst, _outputBuffer)
                            : ZSTD_flushStream(_zst, _outputBuffer));
                        WriteCompressedOutput();
                        // A non-zero result is the number of bytes zstd still holds. One output buffer
                        // may not be enough, so keep draining until it reports zero.
                    } while (remaining != UIntPtr.Zero);
                }
                finally
                {
                    outputHandle.Free();
                }
            }

            /// <summary>
            /// Compresses <paramref name="count"/> bytes of <paramref name="buffer"/>, starting at
            /// <paramref name="offset"/>. Any compressed output zstd produces is written to the underlying stream.
            /// </summary>
            /// <exception cref="ObjectDisposedException">The stream has been closed.</exception>
            public override void Write(byte[] buffer, int offset, int count)
            {
                CheckBufferArguments(buffer, offset, count);
                ThrowIfClosed();
                if (count == 0) return;
                var inputHandle = default(GCHandle);
                var outputHandle = default(GCHandle);
                try
                {
                    inputHandle = GCHandle.Alloc(buffer, GCHandleType.Pinned);
                    outputHandle = GCHandle.Alloc(_outputBufferArray, GCHandleType.Pinned);
                    _outputBuffer.Data = outputHandle.AddrOfPinnedObject();
                    while (count > 0)
                    {
                        // Offer zstd at most ZSTD_CStreamInSize() bytes per call.
                        var size = Math.Min(count, _inputBufferSize);
                        _inputBuffer.Data = Marshal.UnsafeAddrOfPinnedArrayElement(buffer, offset);
                        _inputBuffer.Position = UIntPtr.Zero;
                        _inputBuffer.Size = new UIntPtr((uint)size);
                        _outputBuffer.Position = UIntPtr.Zero;
                        Call(ZSTD_compressStream(_zst, _outputBuffer, _inputBuffer));
                        WriteCompressedOutput();
                        // zstd may consume less than it was offered; advance only past what it actually read.
                        var consumed = (int)_inputBuffer.Position.ToUInt32();
                        count -= consumed;
                        offset += consumed;
                    }
                }
                finally
                {
                    if (inputHandle.IsAllocated) inputHandle.Free();
                    if (outputHandle.IsAllocated) outputHandle.Free();
                }
            }

            /// <summary>Writes the bytes zstd placed in the output buffer to the underlying stream.</summary>
            private void WriteCompressedOutput()
            {
                var produced = (int)_outputBuffer.Position.ToUInt32();
                if (produced > 0) _outputStream.Write(_outputBufferArray, 0, produced);
            }

            /// <summary>Throws if <c>Close</c> has already released the native stream.</summary>
            private void ThrowIfClosed()
            {
                if (_closed) throw new ObjectDisposedException(GetType().Name);
            }

            /// <summary>Always <c>false</c>; this stream is write-only.</summary>
            public override bool CanRead
            {
                get { return false; }
            }

            /// <summary>Always <c>false</c>.</summary>
            public override bool CanSeek
            {
                get { return false; }
            }

            /// <summary><c>true</c> while the stream is open and the underlying stream is writable.</summary>
            public override bool CanWrite
            {
                get { return !_closed && _outputStream.CanWrite; }
            }

            /// <summary>Not tracked; always returns 0.</summary>
            public override long Length
            {
                get { return 0; }
            }

            /// <summary>Not tracked; the getter always returns 0 and the setter throws.</summary>
            public override long Position
            {
                get { return 0; }
                set
                {
                    throw new NotSupportedException();
                }
            }
        }

        /// <summary>
        /// Read-only stream that decompresses zstd data read from an underlying stream.
        /// </summary>
        public class ZStdInputStream : Stream
        {
            private readonly Stream _inputStream;
            private readonly bool _leaveOpen;
            // Native ZSTD_DStream handle; freed in Close.
            private readonly IntPtr _zst;
            // Staging area for compressed bytes read from _inputStream (ZSTD_DStreamInSize() bytes).
            private readonly byte[] _inputBufferArray;
            private readonly Buffer _inputBuffer = new Buffer();
            private readonly Buffer _outputBuffer = new Buffer();
            // Unconsumed compressed bytes are _inputBufferArray[_inputArrayPosition .. _inputArraySize).
            private int _inputArrayPosition;
            private int _inputArraySize;
            // Set once _inputStream.Read has returned no data.
            private bool _depleted;
            private bool _closed;

            /// <summary>
            /// Creates a decompressing stream on top of <paramref name="inputStream"/>.
            /// </summary>
            /// <param name="inputStream">Source of the compressed bytes.</param>
            /// <param name="leaveOpen">When <c>false</c>, closing this stream also closes <paramref name="inputStream"/>.</param>
            /// <exception cref="ArgumentNullException"><paramref name="inputStream"/> is null.</exception>
            /// <exception cref="IOException">The native decompression stream could not be created or initialized.</exception>
            public ZStdInputStream(Stream inputStream, bool leaveOpen = false)
            {
                if (inputStream == null) throw new ArgumentNullException("inputStream");
                _inputStream = inputStream;
                _leaveOpen = leaveOpen;
                _zst = ZSTD_createDStream();
                if (_zst == IntPtr.Zero) throw new IOException("ZSTD_createDStream failed.");
                try
                {
                    Call(ZSTD_initDStream(_zst));
                }
                catch
                {
                    // The constructor is failing, so Close will never run; release the native stream here.
                    ZSTD_freeDStream(_zst);
                    throw;
                }
                _inputBufferArray = new byte[(int)ZSTD_DStreamInSize().ToUInt32()];
            }

            /// <summary>
            /// Releases the native decompression stream and, unless <c>leaveOpen</c> was set,
            /// closes the underlying stream. Later calls do nothing.
            /// </summary>
            public override void Close()
            {
                if (_closed) return;
                _closed = true;
                // The result is deliberately not checked: Close should not throw while cleaning up.
                ZSTD_freeDStream(_zst);
                try
                {
                    if (!_leaveOpen) _inputStream.Close();
                }
                finally
                {
                    base.Close();
                }
            }

            /// <summary>Does nothing; a read-only stream has nothing to flush.</summary>
            public override void Flush()
            {
            }

            /// <summary>Not supported; this stream is not seekable.</summary>
            public override long Seek(long offset, SeekOrigin origin)
            {
                throw new NotSupportedException();
            }

            /// <summary>Not supported; this stream is not seekable.</summary>
            public override void SetLength(long value)
            {
                throw new NotSupportedException();
            }

            /// <summary>
            /// Decompresses into <paramref name="buffer"/>. It keeps reading compressed input until
            /// <paramref name="count"/> bytes have been produced, or until the underlying stream is exhausted
            /// and zstd produces no more output.
            /// </summary>
            /// <returns>The number of decompressed bytes written; 0 at end of data.</returns>
            /// <exception cref="ObjectDisposedException">The stream has been closed.</exception>
            /// <exception cref="IOException">zstd reported a decompression error.</exception>
            public override int Read(byte[] buffer, int offset, int count)
            {
                CheckBufferArguments(buffer, offset, count);
                ThrowIfClosed();
                if (count == 0) return 0;
                var retVal = 0;
                var inputHandle = default(GCHandle);
                var outputHandle = default(GCHandle);
                try
                {
                    inputHandle = GCHandle.Alloc(_inputBufferArray, GCHandleType.Pinned);
                    outputHandle = GCHandle.Alloc(buffer, GCHandleType.Pinned);
                    while (count > 0)
                    {
                        var left = _inputArraySize - _inputArrayPosition;
                        if (left <= 0 && !_depleted)
                        {
                            // Refill the compressed-input staging buffer from the underlying stream.
                            _inputArrayPosition = 0;
                            _inputArraySize = left = _inputStream.Read(_inputBufferArray, 0, _inputBufferArray.Length);
                            if (left <= 0)
                            {
                                // No more compressed data at all.
                                left = 0;
                                _depleted = true;
                            }
                        }
                        _inputBuffer.Position = UIntPtr.Zero;
                        if (_depleted)
                        {
                            // Once the source is exhausted, keep calling zstd with empty input
                            // until it stops producing output.
                            _inputBuffer.Size = UIntPtr.Zero;
                            _inputBuffer.Data = IntPtr.Zero;
                        }
                        else
                        {
                            _inputBuffer.Size = new UIntPtr((uint)left);
                            _inputBuffer.Data = Marshal.UnsafeAddrOfPinnedArrayElement(_inputBufferArray, _inputArrayPosition);
                        }
                        _outputBuffer.Position = UIntPtr.Zero;
                        _outputBuffer.Size = new UIntPtr((uint)count);
                        _outputBuffer.Data = Marshal.UnsafeAddrOfPinnedArrayElement(buffer, offset);
                        // NOTE(review): the non-error result (remaining-input hint) is not inspected, so a
                        // truncated frame simply ends the stream here instead of raising an error.
                        Call(ZSTD_decompressStream(_zst, _outputBuffer, _inputBuffer));
                        var bytesProduced = (int)_outputBuffer.Position.ToUInt32();
                        if (bytesProduced == 0 && _depleted) break;
                        retVal += bytesProduced;
                        count -= bytesProduced;
                        offset += bytesProduced;
                        if (_depleted) continue;
                        _inputArrayPosition += (int)_inputBuffer.Position.ToUInt32();
                    }
                    return retVal;
                }
                finally
                {
                    if (outputHandle.IsAllocated) outputHandle.Free();
                    if (inputHandle.IsAllocated) inputHandle.Free();
                }
            }

            /// <summary>Not supported; this stream is read-only.</summary>
            public override void Write(byte[] buffer, int offset, int count)
            {
                throw new NotSupportedException();
            }

            /// <summary>Throws if <c>Close</c> has already released the native stream.</summary>
            private void ThrowIfClosed()
            {
                if (_closed) throw new ObjectDisposedException(GetType().Name);
            }

            /// <summary><c>true</c> while the stream is open and the underlying stream is readable.</summary>
            public override bool CanRead
            {
                get { return !_closed && _inputStream.CanRead; }
            }

            /// <summary>Always <c>false</c>.</summary>
            public override bool CanSeek
            {
                get { return false; }
            }

            /// <summary>Always <c>false</c>; this stream is read-only.</summary>
            public override bool CanWrite
            {
                get { return false; }
            }

            /// <summary>Not tracked; always returns 0.</summary>
            public override long Length
            {
                get { return 0; }
            }

            /// <summary>Not tracked; the getter returns 0 and assignments are ignored (kept for backward compatibility).</summary>
            public override long Position
            {
                get { return 0; }
                set { }
            }
        }

        /// <summary>
        /// Managed mirror of zstd's <c>ZSTD_inBuffer</c>/<c>ZSTD_outBuffer</c> structs
        /// (pointer, size, position), passed by pointer to the streaming functions.
        /// </summary>
        [StructLayout(LayoutKind.Sequential)]
        [SuppressMessage("ReSharper", "NotAccessedField.Local")]
        private sealed class Buffer
        {
            public IntPtr Data;
            public UIntPtr Size;
            public UIntPtr Position;
        }

        // https://github.com/facebook/zstd/blob/dev/lib/zstd.h
        // https://facebook.github.io/zstd/zstd_manual.html
        // https://github.com/facebook/zstd/blob/master/doc/zstd_compression_format.md

        /// <summary>
        /// Tests whether a buffer starts with the zstd frame magic number 0xFD2FB528 (stored little-endian).
        /// </summary>
        /// <param name="buffBytes">Bytes to inspect; may be null.</param>
        /// <param name="buffLen">Number of valid bytes at the start of <paramref name="buffBytes"/>.</param>
        /// <returns><c>true</c> when at least four valid bytes are present and they match the magic number.</returns>
        public static bool IsZstdStream(byte[] buffBytes, int buffLen)
        {
            if (buffBytes == null) return false;
            // Never index past the end of the array, even if buffLen claims more bytes.
            var available = Math.Min(buffLen, buffBytes.Length);
            //0xFD2FB528 LE
            return available > 3
                && buffBytes[0] == 0x28
                && buffBytes[1] == 0xB5
                && buffBytes[2] == 0x2F
                && buffBytes[3] == 0xFD;
        }

        private const string DllName = "libzstd";

        /// <summary>Returns <c>ZSTD_maxCLevel()</c>, the highest supported compression level.</summary>
        /// <remarks>The misspelled name is kept for compatibility with existing callers.</remarks>
        [DllImport(DllName, EntryPoint = "ZSTD_maxCLevel", CallingConvention = CallingConvention.Cdecl)]
        public static extern int GetMaxCompessionLevel();

        /// <summary>Returns <c>ZSTD_versionNumber()</c> of the loaded native library.</summary>
        [DllImport(DllName, EntryPoint = "ZSTD_versionNumber", CallingConvention = CallingConvention.Cdecl)]
        public static extern int GetVersionNumber();

        // ZSTD_versionString is not bound: returning 'string' from P/Invoke would make the runtime free
        // the library's static buffer. The version text is instead derived from GetVersionNumber().

        /// <summary>
        /// Formats <see cref="GetVersionNumber"/> as "major.minor.release" by splitting the number
        /// into n/10000, (n%10000)/100 and n%100.
        /// </summary>
        public static string GetVersionString()
        {
            var n = GetVersionNumber();
            return string.Format("{0}.{1}.{2}", n/10000, (n%10000)/100, n%100);
        }

        [DllImport(DllName, EntryPoint = "ZSTD_isError", CallingConvention = CallingConvention.Cdecl)]
        private static extern int ZSTD_isError(UIntPtr code);

        // Returns a pointer to a string owned by libzstd; read it with Marshal.PtrToStringAnsi, never free it.
        [DllImport(DllName, EntryPoint = "ZSTD_getErrorName", CallingConvention = CallingConvention.Cdecl)]
        private static extern IntPtr ZSTD_getErrorName(UIntPtr code);

        [DllImport(DllName, EntryPoint = "ZSTD_createCStream", CallingConvention = CallingConvention.Cdecl)]
        private static extern IntPtr ZSTD_createCStream();

        [DllImport(DllName, EntryPoint = "ZSTD_freeCStream", CallingConvention = CallingConvention.Cdecl)]
        private static extern UIntPtr ZSTD_freeCStream(IntPtr zcs);

        [DllImport(DllName, EntryPoint = "ZSTD_initCStream", CallingConvention = CallingConvention.Cdecl)]
        private static extern UIntPtr ZSTD_initCStream(IntPtr zcs, int compressionLevel);

        [DllImport(DllName, EntryPoint = "ZSTD_compressStream", CallingConvention = CallingConvention.Cdecl)]
        private static extern UIntPtr ZSTD_compressStream(IntPtr zcs,
            [MarshalAs(UnmanagedType.LPStruct)] Buffer outputBuffer,
            [MarshalAs(UnmanagedType.LPStruct)] Buffer inputBuffer);

        [DllImport(DllName, EntryPoint = "ZSTD_flushStream", CallingConvention = CallingConvention.Cdecl)]
        private static extern UIntPtr ZSTD_flushStream(IntPtr zcs,
            [MarshalAs(UnmanagedType.LPStruct)] Buffer outputBuffer);

        [DllImport(DllName, EntryPoint = "ZSTD_endStream", CallingConvention = CallingConvention.Cdecl)]
        private static extern UIntPtr ZSTD_endStream(IntPtr zcs,
            [MarshalAs(UnmanagedType.LPStruct)] Buffer outputBuffer);

        [DllImport(DllName, EntryPoint = "ZSTD_CStreamInSize", CallingConvention = CallingConvention.Cdecl)]
        private static extern UIntPtr ZSTD_CStreamInSize();

        [DllImport(DllName, EntryPoint = "ZSTD_CStreamOutSize", CallingConvention = CallingConvention.Cdecl)]
        private static extern UIntPtr ZSTD_CStreamOutSize();

        [DllImport(DllName, EntryPoint = "ZSTD_createDStream", CallingConvention = CallingConvention.Cdecl)]
        private static extern IntPtr ZSTD_createDStream();

        [DllImport(DllName, EntryPoint = "ZSTD_freeDStream", CallingConvention = CallingConvention.Cdecl)]
        private static extern UIntPtr ZSTD_freeDStream(IntPtr zcs);

        [DllImport(DllName, EntryPoint = "ZSTD_initDStream", CallingConvention = CallingConvention.Cdecl)]
        private static extern UIntPtr ZSTD_initDStream(IntPtr zcs);

        [DllImport(DllName, EntryPoint = "ZSTD_decompressStream", CallingConvention = CallingConvention.Cdecl)]
        private static extern UIntPtr ZSTD_decompressStream(IntPtr zcs,
            [MarshalAs(UnmanagedType.LPStruct)] Buffer outputBuffer,
            [MarshalAs(UnmanagedType.LPStruct)] Buffer inputBuffer);

        [DllImport(DllName, EntryPoint = "ZSTD_DStreamInSize", CallingConvention = CallingConvention.Cdecl)]
        private static extern UIntPtr ZSTD_DStreamInSize();

        [DllImport(DllName, EntryPoint = "ZSTD_DStreamOutSize", CallingConvention = CallingConvention.Cdecl)]
        private static extern UIntPtr ZSTD_DStreamOutSize();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment