Skip to content

Instantly share code, notes, and snippets.

@smourier
Created August 31, 2024 13:27
Show Gist options
  • Save smourier/c6d83ac8c7294a44086976c5eaa350ce to your computer and use it in GitHub Desktop.
Save smourier/c6d83ac8c7294a44086976c5eaa350ce to your computer and use it in GitHub Desktop.
On-Demand IDataObject
using System;
using System.Collections.Generic;
using System.Drawing;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices.ComTypes;
using System.Windows.Forms;
namespace OnDemandDataObject;
/// <summary>
/// Sample entry point: offers two files as on-demand ("virtual") files on the clipboard and keeps
/// the process alive until the user dismisses a message box.
/// </summary>
internal class Program
{
    [STAThread]
    static void Main()
    {
        var dataObject = new DataObject();

        // Only metadata is gathered here; each file's content stream is produced when a consumer reads it.
        var onDemandFiles = new[]
        {
            FileDescriptor.FromFile(@"d:\temp\first.pdf"),
            FileDescriptor.FromFile(@"d:\temp\second.pdf"),
            // etc.
        };

        dataObject.SetOnDemandFiles(onDemandFiles);
        dataObject.SetToClipboard();

        // MessageBox.Show runs a message pump on this thread while the data sits on the clipboard.
        MessageBox.Show("Press ok to remove files from the clipboard");
    }
}
/// <summary>
/// One virtual file offered on the clipboard: its shell <see cref="FILEDESCRIPTOR"/> metadata plus a
/// factory that produces the file's content stream when a consumer asks to read it.
/// </summary>
public class FileDescriptor()
{
    /// <summary>Gets or sets the shell metadata (name, size, times, attributes) of the file.</summary>
    public virtual FILEDESCRIPTOR Descriptor { get; set; }

    /// <summary>
    /// Gets or sets the factory called when the content stream is requested for reading (on paste action).
    /// </summary>
    public virtual Func<Stream>? GetStream { get; set; }

    /// <summary>Creates a descriptor for the file at <paramref name="filePath"/>.</summary>
    /// <param name="filePath">Path of the file whose metadata is captured.</param>
    /// <param name="getStream">
    /// Optional content factory; when null, the file at <paramref name="filePath"/> is opened read-only on demand.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="filePath"/> is null.</exception>
    public static FileDescriptor FromFile(string filePath, Func<Stream>? getStream = null)
    {
        ArgumentNullException.ThrowIfNull(filePath);
        return FromFile(new FileInfo(filePath), getStream);
    }

    /// <summary>Creates a descriptor for the file described by <paramref name="info"/>.</summary>
    /// <param name="info">File whose metadata is captured.</param>
    /// <param name="getStream">
    /// Optional content factory; when null, the file is opened read-only on demand.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="info"/> is null.</exception>
    public static FileDescriptor FromFile(FileInfo info, Func<Stream>? getStream = null)
    {
        ArgumentNullException.ThrowIfNull(info);

        // Previously `fd.GetStream ??= ...` ran on a fresh instance, so the caller's factory was
        // always discarded. Honor it and fall back to opening the file only when none was supplied.
        // FileShare.ReadWrite lets the read proceed even if another handle has the file open for writing.
        return new FileDescriptor
        {
            GetStream = getStream ?? (() => new FileStream(info.FullName, FileMode.Open, FileAccess.Read, FileShare.ReadWrite)),
            Descriptor = FILEDESCRIPTOR.FromFileInfo(info),
        };
    }
}
/// <summary>
/// Minimal COM <c>IDataObject</c> for putting "virtual" files on the OLE clipboard. Each file is published
/// as an entry of a <c>FileGroupDescriptorW</c> block plus a <c>FileContents</c> <see cref="IStream"/>;
/// the file's content is only produced when a consumer reads that stream.
/// </summary>
public class DataObject : DataObject.IDataObject, IDisposable
{
    // (format, medium, owned): when owned is true this object must free the medium with ReleaseStgMedium
    // when it is replaced by SetData or when the object is disposed/finalized.
    private readonly List<(FORMATETC, STGMEDIUM, bool)> _data = [];
    private bool _disposedValue;

    /// <summary>Initializes OLE on the calling thread and places this object on the OLE clipboard.</summary>
    /// <remarks>
    /// WinForms' clipboard API can't be used with a custom IDataObject, so OLE is called directly.
    /// A failing HRESULT from either call is thrown as the exception <see cref="Marshal.ThrowExceptionForHR(int)"/>
    /// maps it to; success codes (including S_FALSE for "already initialized") are ignored.
    /// </remarks>
    public virtual void SetToClipboard()
    {
        Marshal.ThrowExceptionForHR(OleInitialize(0));
        Marshal.ThrowExceptionForHR(OleSetClipboard(this));
    }

    /// <summary>
    /// Publishes <paramref name="files"/> as virtual files: one <c>FileGroupDescriptorW</c> HGLOBAL describing
    /// all of them, plus one <c>FileContents</c> stream per file, addressed by <c>lindex</c> (0-based).
    /// </summary>
    /// <param name="files">Files to expose; each element must be non-null with a non-null GetStream.</param>
    /// <exception cref="ArgumentNullException"><paramref name="files"/> is null.</exception>
    /// <exception cref="ArgumentException">An element is null or has no GetStream factory.</exception>
    public virtual void SetOnDemandFiles(IReadOnlyList<FileDescriptor> files)
    {
        ArgumentNullException.ThrowIfNull(files);
        if (files.Count == 0)
            return;

        if (files.Any(d => d?.GetStream == null))
            throw new ArgumentException("Every file must be non-null and provide a GetStream factory.", nameof(files));

        // build CFSTR_FILEDESCRIPTORW: an int item count followed by the FILEDESCRIPTOR array
        var elementSize = Marshal.SizeOf<FILEDESCRIPTOR>();
        var size = Marshal.SizeOf<int>() + elementSize * files.Count;
        var fileDescriptorsPtr = Marshal.AllocHGlobal(size);
        var descriptorsOwnedByData = false;
        try
        {
            var current = fileDescriptorsPtr;
            Marshal.WriteInt32(current, files.Count);
            current += Marshal.SizeOf<int>();
            foreach (var file in files)
            {
                Marshal.StructureToPtr(file.Descriptor, current, false);
                current += elementSize;
            }

            // set CFSTR_FILEDESCRIPTORW; from here on _data owns the HGLOBAL
            var fmt = new FORMATETC { dwAspect = DVASPECT.DVASPECT_CONTENT, cfFormat = (short)RegisterClipboardFormat(CFSTR_FILEDESCRIPTORW), lindex = -1, tymed = TYMED.TYMED_HGLOBAL };
            var medium = new STGMEDIUM { tymed = fmt.tymed, unionmember = fileDescriptorsPtr };
            ((IDataObject)this).SetData(ref fmt, ref medium, true);
            descriptorsOwnedByData = true;

            // set all CFSTR_FILECONTENTS, one stream per file, indexed by lindex
            var format = RegisterClipboardFormat(CFSTR_FILECONTENTS);
            fmt = new FORMATETC { dwAspect = DVASPECT.DVASPECT_CONTENT, cfFormat = (short)format, tymed = TYMED.TYMED_ISTREAM };
            medium = new STGMEDIUM { tymed = fmt.tymed };
            for (var i = 0; i < files.Count; i++)
            {
                fmt.lindex = i;
                var stream = new ReadStream(files[i]);
                var unk = Marshal.GetComInterfaceForObject(stream, typeof(IStream));
                var streamOwnedByData = false;
                try
                {
                    medium.unionmember = unk;
                    medium.pUnkForRelease = stream;
                    ((IDataObject)this).SetData(ref fmt, ref medium, true);
                    streamOwnedByData = true;
                }
                finally
                {
                    // Once stored with fRelease=true, the reference returned by GetComInterfaceForObject belongs
                    // to _data and is dropped by ReleaseStgMedium (on replace or dispose). Releasing it here as
                    // well would leave _data holding an unowned pointer that is later released a second time.
                    if (!streamOwnedByData)
                        Marshal.Release(unk);
                }
            }
        }
        finally
        {
            // Free only if ownership never reached _data; otherwise SetData/Dispose would free it again.
            if (!descriptorsOwnedByData)
                Marshal.FreeHGlobal(fileDescriptorsPtr);
        }
    }

    /// <summary>Releases every medium this object owns and empties its format list.</summary>
    /// <param name="disposing">True when called from <see cref="Dispose()"/>, false from the finalizer.</param>
    /// <remarks>
    /// The owned mediums are native resources (HGLOBALs, COM references), so they are released on both the
    /// Dispose and the finalizer path; <c>_data</c> is a plain managed list that is still usable from the finalizer.
    /// Do not dispose while this object is still on the clipboard: consumers would receive freed mediums.
    /// </remarks>
    protected virtual void Dispose(bool disposing)
    {
        if (_disposedValue)
            return;

        foreach (var (_, medium, owned) in _data)
        {
            if (!owned)
                continue;

            var toRelease = medium;
            ReleaseStgMedium(ref toRelease);
        }

        _data.Clear();
        _disposedValue = true;
    }

    ~DataObject() { Dispose(disposing: false); }

    public void Dispose() { Dispose(disposing: true); GC.SuppressFinalize(this); }

    // Finds the stored entry with the same clipboard format and lindex, or -1. This is the single matching rule
    // shared by GetData, GetDataHere, QueryGetData and SetData so they agree on what "the same format" means.
    private int IndexOf(in FORMATETC format)
    {
        for (var i = 0; i < _data.Count; i++)
        {
            var stored = _data[i].Item1;
            if (stored.cfFormat == format.cfFormat && stored.lindex == format.lindex)
                return i;
        }

        return -1;
    }

    int IDataObject.GetData(ref FORMATETC pformatetcIn, out STGMEDIUM pmedium)
    {
        var index = IndexOf(pformatetcIn);
        if (index < 0)
        {
            pmedium = new();
            return DV_E_FORMATETC;
        }

        // Hand the caller its own copy so it can release it independently of the one kept in _data.
        var medium = _data[index].Item2;
        return CopyStgMediumOut(ref medium, out pmedium);
    }

    int IDataObject.GetDataHere(ref FORMATETC pformatetcIn, ref STGMEDIUM pmedium)
    {
        var index = IndexOf(pformatetcIn);
        if (index < 0 || (_data[index].Item1.tymed & pformatetcIn.tymed) == 0)
            return DV_E_FORMATETC;

        var medium = _data[index].Item2;
        // Must be null, not 0: pUnkForRelease is an object marshaled as IUnknown, so 0 would box an int
        // and hand the copy a bogus non-null release object.
        medium.pUnkForRelease = null;
        // NOTE(review): GetDataHere is meant to write into the caller's medium; CopyStgMedium replaces it
        // instead. Kept as in the original — confirm against the consumers that call this method.
        return CopyStgMediumRef(ref medium, ref pmedium);
    }

    int IDataObject.QueryGetData(ref FORMATETC pformatetc)
    {
        var index = IndexOf(pformatetc);
        if (index < 0)
            return DV_E_FORMATETC;

        // tymed is a bit mask on the query side: any overlap with what we store is acceptable.
        return (_data[index].Item1.tymed & pformatetc.tymed) != 0 ? S_OK : DV_E_TYMED;
    }

    int IDataObject.SetData(ref FORMATETC pformatetc, ref STGMEDIUM pmedium, bool fRelease)
    {
        // Replace any entry for the same format/lindex, freeing the old medium if we owned it (unless the caller
        // is handing us the very same handle again, which would otherwise be freed while still in use).
        for (var i = _data.Count - 1; i >= 0; i--)
        {
            var (format, existing, owned) = _data[i];
            if (format.cfFormat != pformatetc.cfFormat || format.lindex != pformatetc.lindex)
                continue;

            _data.RemoveAt(i);
            if (owned && existing.unionmember != pmedium.unionmember)
                ReleaseStgMedium(ref existing);
        }

        _data.Add((pformatetc, pmedium, fRelease));
        return S_OK;
    }

    int IDataObject.GetCanonicalFormatEtc(ref FORMATETC pformatectIn, out FORMATETC pformatetcOut)
    {
        // Simplest documented implementation: same format, without a target device.
        pformatetcOut = pformatectIn;
        pformatetcOut.ptd = IntPtr.Zero;
        return DATA_S_SAMEFORMATETC;
    }

    int IDataObject.EnumFormatEtc(DATADIR dwDirection, out IEnumFORMATETC ppenumFormatEtc) { ppenumFormatEtc = new EnumFORMATETC(this, dwDirection); return S_OK; }

    int IDataObject.DAdvise(ref FORMATETC pformatetc, uint advf, IAdviseSink pAdvSink, out uint dwConnection)
    {
        dwConnection = 0;
        return OLE_E_ADVISENOTSUPPORTED;
    }

    int IDataObject.DUnadvise(uint dwConnection) => OLE_E_ADVISENOTSUPPORTED;

    int IDataObject.EnumDAdvise(out IEnumSTATDATA ppenumAdvise)
    {
        ppenumAdvise = null!;
        return OLE_E_ADVISENOTSUPPORTED;
    }

    /// <summary>Enumerates the FORMATETCs currently stored in the owning <see cref="DataObject"/>.</summary>
    private class EnumFORMATETC(DataObject dataObject, DATADIR direction) : IEnumFORMATETC
    {
        public int Index { get; set; }

        public int Next(int celt, FORMATETC[] rgelt, int[] pceltFetched)
        {
            var items = dataObject._data;
            var fetched = 0;
            // Bound by the cursor position (Index), not by how many were fetched in this call; otherwise a call
            // made after a partial enumeration reads past the end of the list.
            while (fetched < celt && Index < items.Count)
            {
                rgelt[fetched] = items[Index].Item1;
                Index++;
                fetched++;
            }

            if (pceltFetched != null && pceltFetched.Length > 0) { pceltFetched[0] = fetched; }
            return fetched == celt ? S_OK : S_FALSE;
        }

        public int Reset() { Index = 0; return S_OK; }

        // A clone starts at the same position as this enumerator.
        public void Clone(out IEnumFORMATETC newEnum) => newEnum = new EnumFORMATETC(dataObject, direction) { Index = Index };

        public int Skip(int celt)
        {
            var remaining = Math.Max(dataObject._data.Count - Index, 0);
            // celt is an unsigned count natively; a negative value here means "more than int.MaxValue".
            var skipped = celt < 0 ? remaining : Math.Min(celt, remaining);
            Index += skipped;
            return skipped == celt ? S_OK : S_FALSE;
        }
    }

    /// <summary>
    /// Read-only <see cref="IStream"/> over a <see cref="FileDescriptor"/>. The underlying .NET stream is
    /// created lazily, by the first member that needs it.
    /// </summary>
    private class ReadStream : IStream
    {
        private readonly FileDescriptor _descriptor;
        private readonly Lazy<Stream> _stream;

        public ReadStream(FileDescriptor descriptor)
        {
            _descriptor = descriptor;
            _stream = new Lazy<Stream>(() => _descriptor.GetStream!() ?? throw new InvalidOperationException());
        }

        private Stream Stream => _stream.Value;

        // Explorer calls here
        void IStream.Read(byte[] pv, int cb, nint pcbRead)
        {
            var read = Stream.Read(pv, 0, cb);
            if (pcbRead != 0) { Marshal.WriteInt32(pcbRead, read); }
        }

        void IStream.Seek(long dlibMove, int dwOrigin, nint plibNewPosition)
        {
            // STREAM_SEEK_SET/CUR/END have the same values as SeekOrigin.Begin/Current/End.
            var newPos = Stream.Seek(dlibMove, (SeekOrigin)dwOrigin);
            if (plibNewPosition != 0) { Marshal.WriteInt64(plibNewPosition, newPos); }
        }

        public void Stat(out STATSTG pstatstg, int grfStatFlag)
        {
            const int STGTY_STREAM = 2;
            const int STGM_READWRITE = 2;
            const int STGM_WRITE = 1;
            var stream = Stream;
            pstatstg = new STATSTG { type = STGTY_STREAM, cbSize = stream.Length, };
            const int STATFLAG_NONAME = 1;
            if ((grfStatFlag & STATFLAG_NONAME) == 0) pstatstg.pwcsName = _descriptor.Descriptor.cFileName;
            pstatstg.atime = ToFileTime(_descriptor.Descriptor.ftLastAccessTime);
            pstatstg.ctime = ToFileTime(_descriptor.Descriptor.ftCreationTime);
            pstatstg.mtime = ToFileTime(_descriptor.Descriptor.ftLastWriteTime);
            pstatstg.clsid = _descriptor.Descriptor.clsid;
            if (stream.CanRead && stream.CanWrite)
            {
                pstatstg.grfMode |= STGM_READWRITE;
                return;
            }

            if (stream.CanWrite) { pstatstg.grfMode |= STGM_WRITE; }
        }

        // Office (outlook, word, etc.) calls here
        public void CopyTo(IStream pstm, long cb, nint pcbRead, nint pcbWritten)
        {
            ArgumentNullException.ThrowIfNull(pstm);

            // cb is an unsigned 64-bit count natively, and a caller may pass ULLONG_MAX to mean "everything".
            // That value arrives here as a negative long. Treat any negative value as "no limit" instead of
            // computing a negative read size, which would make Stream.Read throw.
            var unlimited = cb < 0;
            var count = 0L;
            var bytes = new byte[0x14000]; // 81920 under loh
            while (unlimited || count < cb)
            {
                var max = unlimited ? bytes.Length : (int)Math.Min(cb - count, bytes.Length);
                var read = Stream.Read(bytes, 0, max);
                if (read == 0)
                    break;

                pstm.Write(bytes, read, 0);
                count += read;
            }

            if (pcbRead != 0) Marshal.WriteInt64(pcbRead, count);
            if (pcbWritten != 0) Marshal.WriteInt64(pcbWritten, count);
            pstm.Commit(0); // STGC_DEFAULT
            Marshal.FinalReleaseComObject(pstm); // we must do this otherwise Office doesn't like it
        }

        public void Commit(int grfCommitFlags) => Stream.Flush();
        void IStream.Write(byte[] pv, int cb, nint pcbWritten) => throw new NotImplementedException();
        void IStream.SetSize(long libNewSize) => throw new NotImplementedException();
        void IStream.Revert() => throw new NotImplementedException();
        void IStream.LockRegion(long libOffset, long cb, int dwLockType) => throw new NotImplementedException();
        void IStream.UnlockRegion(long libOffset, long cb, int dwLockType) => throw new NotImplementedException();
        void IStream.Clone(out IStream ppstm) => throw new NotImplementedException();

        // Splits a 64-bit file time into the low/high halves of a FILETIME.
        private static FILETIME ToFileTime(long fileTime) => new() { dwLowDateTime = (int)(fileTime & uint.MaxValue), dwHighDateTime = (int)(fileTime >> 32) };
    }

    private const int S_OK = 0;
    private const int S_FALSE = 1;
    private const int DATA_S_SAMEFORMATETC = 0x00040130;
    private const int OLE_E_ADVISENOTSUPPORTED = unchecked((int)0x80040003);
    private const int DV_E_FORMATETC = unchecked((int)0x80040064);
    private const int DV_E_TYMED = unchecked((int)0x80040069);
    private const string CFSTR_FILEDESCRIPTORW = "FileGroupDescriptorW";
    private const string CFSTR_FILECONTENTS = "FileContents";

    [DllImport("user32", CharSet = CharSet.Unicode)]
    private static extern int RegisterClipboardFormat(string format);

    [DllImport("ole32")]
    private static extern int OleSetClipboard(IDataObject pDataObj);

    [DllImport("ole32")]
    private static extern int OleInitialize(nint pvReserved);

    [DllImport("ole32")]
    private static extern void ReleaseStgMedium(ref STGMEDIUM medium);

    [DllImport("urlmon", EntryPoint = "CopyStgMedium")]
    private static extern int CopyStgMediumOut(ref STGMEDIUM pcstgmedSrc, out STGMEDIUM pstgmedDest);

    [DllImport("urlmon", EntryPoint = "CopyStgMedium")]
    private static extern int CopyStgMediumRef(ref STGMEDIUM pcstgmedSrc, ref STGMEDIUM pstgmedDest);

    // Redefined with [PreserveSig] on every method so each one can report failure by returning an HRESULT
    // instead of throwing.
    [ComImport, Guid("0000010E-0000-0000-C000-000000000046"), InterfaceType(ComInterfaceType.InterfaceIsIUnknown)]
    private interface IDataObject
    {
        [PreserveSig]
        int GetData(ref FORMATETC pformatetcIn, out STGMEDIUM pmedium);
        [PreserveSig]
        int GetDataHere(ref FORMATETC pformatetcIn, ref STGMEDIUM pmedium);
        [PreserveSig]
        int QueryGetData(ref FORMATETC pformatetc);
        [PreserveSig]
        int GetCanonicalFormatEtc(ref FORMATETC pformatectIn, out FORMATETC pformatetcOut);
        [PreserveSig]
        int SetData(ref FORMATETC pformatetc, ref STGMEDIUM pmedium, bool fRelease);
        [PreserveSig]
        int EnumFormatEtc(DATADIR dwDirection, out IEnumFORMATETC ppenumFormatEtc);
        [PreserveSig]
        int DAdvise(ref FORMATETC pformatetc, uint advf, IAdviseSink pAdvSink, out uint dwConnection);
        [PreserveSig]
        int DUnadvise(uint dwConnection);
        [PreserveSig]
        int EnumDAdvise(out IEnumSTATDATA ppenumAdvise);
    }
}
/// <summary>
/// Managed mirror of one element of a <c>FileGroupDescriptorW</c> block (Unicode file name, 260 chars max).
/// </summary>
/// <remarks>
/// The three time members are 64-bit Windows file times, i.e. <see cref="DateTime.ToFileTimeUtc"/> values.
/// </remarks>
[StructLayout(LayoutKind.Sequential, CharSet = CharSet.Unicode)]
public struct FILEDESCRIPTOR
{
    public FD dwFlags;
    public Guid clsid;
    public Size sizel;
    public Point pointl;
    public FileAttributes dwFileAttributes;
    public long ftCreationTime;
    public long ftLastAccessTime;
    public long ftLastWriteTime;
    public uint nFileSizeHigh;
    public uint nFileSizeLow;
    [MarshalAs(UnmanagedType.ByValTStr, SizeConst = 260)]
    public string cFileName;

    /// <summary>Returns the file name, or an empty string for a default-initialized descriptor.</summary>
    public override readonly string ToString() => cFileName ?? string.Empty;

    /// <summary>
    /// Builds a descriptor from <paramref name="info"/>. The name is always set; size, attributes and
    /// timestamps are only filled in (and flagged in <see cref="dwFlags"/>) when the file exists.
    /// </summary>
    /// <param name="info">The file to describe.</param>
    /// <returns>A descriptor flagged <see cref="FD.FD_UNICODE"/> plus one flag per populated member.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="info"/> is null.</exception>
    public static FILEDESCRIPTOR FromFileInfo(FileInfo info)
    {
        ArgumentNullException.ThrowIfNull(info);
        var fd = new FILEDESCRIPTOR { dwFlags = FD.FD_UNICODE, cFileName = info.Name };
        if (!info.Exists)
            return fd;

        fd.dwFlags |= FD.FD_FILESIZE | FD.FD_ATTRIBUTES;
        fd.nFileSizeLow = (uint)(info.Length & uint.MaxValue);
        fd.nFileSizeHigh = (uint)(info.Length >> 32);
        fd.dwFileAttributes = info.Attributes;

        if (TryGetFileTime(info.CreationTimeUtc, out var created))
        {
            fd.ftCreationTime = created;
            fd.dwFlags |= FD.FD_CREATETIME;
        }

        if (TryGetFileTime(info.LastAccessTimeUtc, out var accessed))
        {
            fd.ftLastAccessTime = accessed;
            fd.dwFlags |= FD.FD_ACCESSTIME;
        }

        if (TryGetFileTime(info.LastWriteTimeUtc, out var written))
        {
            fd.ftLastWriteTime = written;
            fd.dwFlags |= FD.FD_WRITESTIME;
        }

        return fd;

        // ToFileTimeUtc throws for instants before 1601-01-01 UTC (the file-time epoch); such timestamps
        // are reported as "not available" (flag left unset) instead of failing the whole descriptor.
        static bool TryGetFileTime(DateTime dt, out long fileTime)
        {
            const long fileTimeEpochTicks = 504911232000000000; // daysTo1601 * ticksPerDay
            var utcTicks = dt.Kind != DateTimeKind.Utc ? dt.ToUniversalTime().Ticks : dt.Ticks;
            if (utcTicks < fileTimeEpochTicks)
            {
                fileTime = 0;
                return false;
            }

            fileTime = dt.ToFileTimeUtc();
            return true;
        }
    }
}
/// <summary>
/// Bit flags stored in <c>FILEDESCRIPTOR.dwFlags</c>. Most flags mark which optional descriptor
/// member carries a meaningful value.
/// </summary>
[Flags]
public enum FD
{
    /// <summary>The <c>clsid</c> member is set.</summary>
    FD_CLSID = 1 << 0,

    /// <summary>The <c>sizel</c> and <c>pointl</c> members are set.</summary>
    FD_SIZEPOINT = 1 << 1,

    /// <summary>The <c>dwFileAttributes</c> member is set.</summary>
    FD_ATTRIBUTES = 1 << 2,

    /// <summary>The <c>ftCreationTime</c> member is set.</summary>
    FD_CREATETIME = 1 << 3,

    /// <summary>The <c>ftLastAccessTime</c> member is set.</summary>
    FD_ACCESSTIME = 1 << 4,

    /// <summary>The <c>ftLastWriteTime</c> member is set.</summary>
    FD_WRITESTIME = 1 << 5,

    /// <summary>The <c>nFileSizeHigh</c>/<c>nFileSizeLow</c> members are set.</summary>
    FD_FILESIZE = 1 << 6,

    /// <summary>Shell progress-UI flag (0x4000); see the shell FILEDESCRIPTOR documentation.</summary>
    FD_PROGRESSUI = 1 << 14,

    /// <summary>Shell link-UI flag (0x8000); see the shell FILEDESCRIPTOR documentation.</summary>
    FD_LINKUI = 1 << 15,

    /// <summary>The descriptor uses Unicode strings (the sign bit, 0x80000000).</summary>
    FD_UNICODE = 1 << 31,
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment