Skip to content

Instantly share code, notes, and snippets.

@Washi1337
Last active June 9, 2024 04:39
Show Gist options
  • Save Washi1337/a35acf49b64b07637a3047eec23c4e58 to your computer and use it in GitHub Desktop.
Save Washi1337/a35acf49b64b07637a3047eec23c4e58 to your computer and use it in GitHub Desktop.
Injecting unconventional entry points in a .NET module. Blog post: https://washi.dev/blog/posts/entry-points/
#include <cstdio>
#include <windows.h>
// TLS callback of DynamicLibrary.dll. It is registered through the
// tls_callback_func1 pointer declared further down in this file and only
// announces itself on standard output.
VOID WINAPI TlsCallback(PVOID DllHandle, DWORD Reason, PVOID Reserved)
{
    // The callback prints the same line regardless of its arguments.
    UNREFERENCED_PARAMETER(DllHandle);
    UNREFERENCED_PARAMETER(Reason);
    UNREFERENCED_PARAMETER(Reserved);

    std::puts("[DynamicLibrary.dll]: TLS Callback");
}
// Ask the linker to keep the TLS directory symbol and the callback pointer
// declared below. The 32-bit symbol names carry one extra leading underscore
// compared to their 64-bit counterparts.
#ifdef _WIN64
#pragma comment (linker, "/INCLUDE:_tls_used")
#pragma comment (linker, "/INCLUDE:tls_callback_func1")
#else
#pragma comment (linker, "/INCLUDE:__tls_used")
#pragma comment (linker, "/INCLUDE:_tls_callback_func1")
#endif
// Emit the callback pointers into the ".CRT$XLF" section: as constant data on
// 64-bit builds, as regular initialized data on 32-bit builds.
// NOTE(review): presumably ".CRT$XLF" falls inside the range the MSVC CRT
// merges into the TLS callback array of the image - confirm against the CRT sources.
#ifdef _WIN64
#pragma const_seg(".CRT$XLF")
EXTERN_C const
#else
#pragma data_seg(".CRT$XLF")
EXTERN_C
#endif
// Pointer to TlsCallback; this is the symbol named in the /INCLUDE pragmas above.
PIMAGE_TLS_CALLBACK tls_callback_func1 = TlsCallback;
// NULL entry directly after the callback pointer.
// NOTE(review): looks like a list terminator - confirm it is required here.
PIMAGE_TLS_CALLBACK tls_callback_end = NULL;
// Restore the default section for everything that follows.
#ifdef _WIN64
#pragma const_seg()
#else
#pragma data_seg()
#endif //_WIN64
// Exported with C linkage under the name "UnmanagedExport".
// Prints a fixed line to standard output and returns.
extern "C" __declspec(dllexport) void UnmanagedExport()
{
    static const char message[] = "Unmanaged Export";
    std::puts(message);
}
// DLL entry point of DynamicLibrary.dll. Announces itself on standard output
// and always reports success, whatever the reason for the call is.
BOOL APIENTRY DllMain(HMODULE hModule, DWORD ul_reason_for_call, LPVOID lpReserved)
{
    // None of the arguments influence the behavior.
    UNREFERENCED_PARAMETER(hModule);
    UNREFERENCED_PARAMETER(ul_reason_for_call);
    UNREFERENCED_PARAMETER(lpReserved);

    std::puts("[DynamicLibrary.dll]: DllMain");
    return TRUE;
}
/// <summary>
/// Selects which entry points <c>EntryPointsInjector.InjectEntryPoints</c> adds to a module.
/// The members are bit flags and can be combined.
/// </summary>
[Flags]
public enum EntryPoints
{
    /// <summary>No optional entry point is selected.</summary>
    None = 0,

    /// <summary>Prepend a console message to the module constructor (<c>&lt;Module&gt;::.cctor</c>).</summary>
    ModuleCctor = 1,

    /// <summary>Add a native entry point that is referenced from the .NET data directory.</summary>
    NativeClrEntryPoint = 2,

    /// <summary>Add extra native code in front of the jump to <c>_CorDllMain</c> at the PE entry point.</summary>
    NativePEEntryPoint = 4,

    /// <summary>Add a TLS directory with a native callback.</summary>
    TlsCallback = 8,

    /// <summary>Add an import of <c>UnmanagedExport</c> from <c>DynamicLibrary.dll</c>.</summary>
    ExternalDll = 16,

    /// <summary>All of the entry points above.</summary>
    All = ModuleCctor | NativeClrEntryPoint | NativePEEntryPoint | TlsCallback | ExternalDll
}
using System.Text;
/// <summary>
/// Injects additional entry points into a .NET module: a module constructor message, a native
/// entry point referenced from the .NET directory, extra code at the PE entry point, a TLS callback
/// and an import of an external native DLL. Every injected native stub prints a message via
/// <c>puts</c> imported from <c>ucrtbase.dll</c>.
/// </summary>
public static class EntryPointsInjector
{
    /// <summary>
    /// Converts <paramref name="module"/> into a 32-bit mixed-mode PE file and injects the
    /// entry points selected by <paramref name="includedEntryPoints"/>.
    /// </summary>
    /// <param name="prefix">Text placed between brackets at the start of every printed message.</param>
    /// <param name="module">The module to rewrite. Its PE kind, machine type and flags are modified.</param>
    /// <param name="includedEntryPoints">The entry points to inject.</param>
    /// <returns>The PE file with rebuilt import, relocation and (optionally) TLS directories.</returns>
    public static PEFile InjectEntryPoints(string prefix, ModuleDefinition module, EntryPoints includedEntryPoints)
    {
        // Force 32-bit for now.
        module.IsILOnly = false;
        module.IsBit32Required = true;
        module.PEKind = OptionalHeaderMagic.PE32;
        module.MachineType = MachineType.I386;

        // Inject .cctor if specified.
        if ((includedEntryPoints & EntryPoints.ModuleCctor) != 0)
            InjectModuleCctor(prefix, module);

        // We're done in .NET land, go into PE mode.
        var dllImage = module.ToPEImage();

        // Import modules that we will be using.
        var ucrtbase = new ImportedModule("ucrtbase.dll");
        var puts = new ImportedSymbol(0, "puts");
        ucrtbase.Symbols.Add(puts);
        dllImage.Imports.Add(ucrtbase);

        var mscoree = new ImportedModule("mscoree.dll");
        var corDllMain = new ImportedSymbol(0, "_CorDllMain");
        mscoree.Symbols.Add(corDllMain);
        dllImage.Imports.Add(mscoree);

        // Inject external dll import if specified.
        if ((includedEntryPoints & EntryPoints.ExternalDll) != 0)
        {
            var dynamicLibrary = new ImportedModule("DynamicLibrary.dll");
            var export = new ImportedSymbol(0, "UnmanagedExport");
            dynamicLibrary.Symbols.Add(export);
            dllImage.Imports.Add(dynamicLibrary);
        }

        // Inject native CLR entry point.
        var nativeClrEntryPoint = (includedEntryPoints & EntryPoints.NativeClrEntryPoint) != 0
            ? InjectNativeClrEntryPoint(prefix, dllImage)
            : null;

        // Inject PE DllMain entry point. The jump to _CorDllMain is always emitted; the flag only
        // controls whether the message-printing code is placed in front of it.
        var peEntryPoint = InjectNativePEEntryPoint(prefix, dllImage, (includedEntryPoints & EntryPoints.NativePEEntryPoint) != 0);

        // Inject TLS directory if specified.
        var tlsSegment = (includedEntryPoints & EntryPoints.TlsCallback) != 0
            ? InjectTlsCallBack(prefix, dllImage)
            : null;

        // We're done in PE image land, move lower into PE file land.
        var dllFile = new ManagedPEFileBuilder().CreateFile(dllImage);

        // Remove original relocs; they are rebuilt below. Tolerate a builder output without a
        // .reloc section instead of failing with an InvalidOperationException.
        var originalRelocs = dllFile.Sections.FirstOrDefault(s => s.Name == ".reloc");
        if (originalRelocs is not null)
            dllFile.Sections.Remove(originalRelocs);

        // Build up new text section.
        var newTextBuilder = new SegmentBuilder();
        newTextBuilder.Add(peEntryPoint, 8);
        if (nativeClrEntryPoint is not null)
            newTextBuilder.Add(nativeClrEntryPoint, 8);

        dllFile.Sections.Add(new PESection(
            ".text2",
            SectionFlags.MemoryRead | SectionFlags.MemoryExecute | SectionFlags.ContentCode,
            newTextBuilder));

        // Add TLS section if necessary. It is marked executable because the callback code
        // itself is stored in this section.
        if (tlsSegment is not null)
        {
            dllFile.Sections.Add(new PESection(
                ".tls",
                SectionFlags.ContentInitializedData | SectionFlags.MemoryRead | SectionFlags.MemoryWrite | SectionFlags.MemoryExecute,
                tlsSegment));
        }

        // Rebuild imports and relocs.
        var imports = RebuildImportDirectories(dllImage, dllFile);
        var relocs = RebuildBaseRelocations(dllImage, dllFile);

        // Calculate offsets. The RVAs read below are only meaningful after this call.
        dllFile.UpdateHeaders();

        // Update entry points and data directories.
        dllFile.OptionalHeader.AddressOfEntryPoint = peEntryPoint.Rva;
        if (nativeClrEntryPoint is not null)
            dllImage.DotNetDirectory!.EntryPoint = nativeClrEntryPoint.Rva;

        var directories = dllFile.OptionalHeader.DataDirectories;
        directories[(int) DataDirectoryIndex.ImportDirectory] = new(imports.Rva, imports.GetPhysicalSize());
        directories[(int) DataDirectoryIndex.IatDirectory] = new(imports.ImportAddressDirectory.Rva, imports.ImportAddressDirectory.GetPhysicalSize());
        directories[(int) DataDirectoryIndex.BaseRelocationDirectory] = new(relocs.Rva, relocs.GetPhysicalSize());
        if (dllImage.TlsDirectory is not null)
            directories[(int) DataDirectoryIndex.TlsDirectory] = new(dllImage.TlsDirectory.Rva, dllImage.TlsDirectory.GetPhysicalSize());

        return dllFile;
    }

    /// <summary>
    /// Prepends a <c>Console.WriteLine</c> call to the module constructor of
    /// <paramref name="module"/>, creating the constructor when it does not exist yet.
    /// </summary>
    /// <param name="prefix">Text placed between brackets at the start of the printed message.</param>
    /// <param name="module">The module whose constructor is modified.</param>
    public static void InjectModuleCctor(string prefix, ModuleDefinition module)
    {
        var cctor = module.GetOrCreateModuleConstructor();
        cctor.CilMethodBody!.Instructions.InsertRange(0, new[]
        {
            new CilInstruction(CilOpCodes.Ldstr, $"[{prefix}]: {cctor.DeclaringType}::.cctor()"),
            new CilInstruction(CilOpCodes.Call, module.CorLibTypeFactory.CorLibScope
                .CreateTypeReference("System", "Console")
                .CreateMemberReference("WriteLine", MethodSignature.CreateStatic(
                    module.CorLibTypeFactory.Void,
                    module.CorLibTypeFactory.String))
                .ImportWith(module.DefaultImporter))
        });
    }

    /// <summary>
    /// Encodes <paramref name="text"/> as ASCII followed by a terminating zero byte, which is the
    /// string format <c>puts</c> expects. Without the terminator the output would depend on
    /// whatever bytes happen to follow the message in the final image.
    /// </summary>
    private static byte[] ToCString(string text)
    {
        return Encoding.ASCII.GetBytes(text + "\0");
    }

    /// <summary>
    /// Builds the code for the PE entry point: optionally a stub that prints a message, always
    /// followed by a thunk that jumps to <c>_CorDllMain</c>. Required base relocations are
    /// registered in <paramref name="image"/>.
    /// </summary>
    /// <param name="prefix">Text placed between brackets at the start of the printed message.</param>
    /// <param name="image">The image that already imports <c>puts</c> and <c>_CorDllMain</c>.</param>
    /// <param name="injectAdditionalCode">Whether the message-printing stub is emitted.</param>
    /// <returns>The segment whose start is the new entry point.</returns>
    private static ISegment InjectNativePEEntryPoint(string prefix, IPEImage image, bool injectAdditionalCode)
    {
        var puts = image
            .Imports.First(m => m.Name == "ucrtbase.dll")
            .Symbols.First(s => s.Name == "puts");
        var corDllMain = image
            .Imports.First(m => m.Name == "mscoree.dll")
            .Symbols.First(s => s.Name == "_CorDllMain");

        var messageSegment = new DataSegment(ToCString($"[{prefix}]: DllMain"));
        var result = new SegmentBuilder();

        if (injectAdditionalCode)
        {
            var messageSymbol = new Symbol(messageSegment.ToReference());
            var code = new DataSegment(new byte[]
                {
                    /* 00000000: */ 0x68, 0x00, 0x00, 0x00, 0x00,       // push &message
                    /* 00000005: */ 0xFF, 0x15, 0x00, 0x00, 0x00, 0x00, // call [&puts]
                    /* 0000000B: */ 0x83, 0xC4, 0x04,                   // add esp, 4
                }).AsPatchedSegment()
                .Patch(1, AddressFixupType.Absolute32BitAddress, messageSymbol)
                .Patch(7, AddressFixupType.Absolute32BitAddress, puts);

            // Both operands are absolute addresses and therefore need base relocations.
            image.Relocations.Add(new BaseRelocation(RelocationType.HighLow, code.ToReference(1)));
            image.Relocations.Add(new BaseRelocation(RelocationType.HighLow, code.ToReference(7)));
            result.Add(code);
        }

        // The stub above has no return instruction; execution falls through into this thunk.
        var bootstrapper = Platform.Get(image.MachineType).CreateThunkStub(corDllMain);
        foreach (var reloc in bootstrapper.Relocations)
            image.Relocations.Add(reloc);
        result.Add(bootstrapper.Segment);

        // The message is stored after the code so that it is never executed.
        if (injectAdditionalCode)
            result.Add(messageSegment);

        return result;
    }

    /// <summary>
    /// Builds a native entry point that prints a message and returns 1, and flags the .NET
    /// directory of <paramref name="image"/> as having a native entry point. Required base
    /// relocations are registered in <paramref name="image"/>.
    /// </summary>
    /// <param name="prefix">Text placed between brackets at the start of the printed message.</param>
    /// <param name="image">The image that already imports <c>puts</c>.</param>
    /// <returns>The segment containing the code followed by its message.</returns>
    private static ISegment InjectNativeClrEntryPoint(string prefix, IPEImage image)
    {
        var puts = image
            .Imports.First(m => m.Name == "ucrtbase.dll")
            .Symbols.First(s => s.Name == "puts");

        image.DotNetDirectory!.Flags |= DotNetDirectoryFlags.NativeEntryPoint;

        var code = new DataSegment(new byte[]
            {
                /* 00000000: */ 0x68, 0x00, 0x00, 0x00, 0x00,       // push &message
                /* 00000005: */ 0xFF, 0x15, 0x00, 0x00, 0x00, 0x00, // call [&puts]
                /* 0000000B: */ 0x83, 0xC4, 0x04,                   // add esp, 4
                /* 0000000E: */ 0xB8, 0x01, 0x00, 0x00, 0x00,       // mov eax, 1
                /* 00000013: */ 0xC2, 0x0c, 0x00,                   // ret 0xc
                /* 00000016: */                                     // message:
            }.Concat(ToCString($"[{prefix}]: Unmanaged Entry Point from CLR directory")).ToArray())
            .AsPatchedSegment()
            .Patch(relativeOffset: 0x1, AddressFixupType.Absolute32BitAddress, symbolOffset: +0x16 /* &message */)
            .Patch(relativeOffset: 0x7, AddressFixupType.Absolute32BitAddress, puts);

        // Both operands are absolute addresses and therefore need base relocations.
        image.Relocations.Add(new BaseRelocation(RelocationType.HighLow, code.ToReference(0x1)));
        image.Relocations.Add(new BaseRelocation(RelocationType.HighLow, code.ToReference(0x7)));

        return code;
    }

    /// <summary>
    /// Builds a TLS callback that prints a message, assigns a new TLS directory referencing it to
    /// <paramref name="image"/>, and registers the required base relocations.
    /// </summary>
    /// <param name="prefix">Text placed between brackets at the start of the printed message.</param>
    /// <param name="image">The image that already imports <c>puts</c>.</param>
    /// <returns>
    /// The segment holding the callback code, the TLS template data, the TLS index, the TLS
    /// directory and its callback table, in that order.
    /// </returns>
    private static ISegment InjectTlsCallBack(string prefix, IPEImage image)
    {
        var puts = image
            .Imports.First(m => m.Name == "ucrtbase.dll")
            .Symbols.First(s => s.Name == "puts");

        var code = new DataSegment(new byte[]
            {
                /* 00000000: */ 0x68, 0x00, 0x00, 0x00, 0x00,       // push &message
                /* 00000005: */ 0xFF, 0x15, 0x00, 0x00, 0x00, 0x00, // call [&puts]
                /* 0000000B: */ 0x83, 0xC4, 0x04,                   // add esp, 4
                /* 0000000E: */ 0xC2, 0x0C, 0x00                    // ret 0xC
                // 00000011: message:
            }.Concat(ToCString($"[{prefix}]: TLS Callback")).ToArray())
            .AsPatchedSegment()
            .Patch(0x1, AddressFixupType.Absolute32BitAddress, 0x11)
            .Patch(0x7, AddressFixupType.Absolute32BitAddress, puts);

        // Both operands are absolute addresses and therefore need base relocations.
        image.Relocations.Add(new BaseRelocation(RelocationType.HighLow, code.ToReference(0x1)));
        image.Relocations.Add(new BaseRelocation(RelocationType.HighLow, code.ToReference(0x7)));

        var templateBlock = new DataSegment(new byte[100]);
        var indexBlock = new DataSegment(new byte[8]);

        image.TlsDirectory = new TlsDirectory
        {
            TemplateData = templateBlock,
            Index = indexBlock.ToReference(),
            CallbackFunctions = { code.ToReference() },
            Characteristics = TlsCharacteristics.Align4Bytes
        };

        foreach (var reloc in image.TlsDirectory.GetRequiredBaseRelocations())
            image.Relocations.Add(reloc);

        return new SegmentBuilder
        {
            { code, 8 },
            { templateBlock, 8 },
            { indexBlock, 8 },
            { image.TlsDirectory, 8 },
            { image.TlsDirectory.CallbackFunctions, 8 }
        };
    }

    /// <summary>
    /// Collects every relocation registered in <paramref name="image"/> into a new relocation
    /// directory and appends it to <paramref name="file"/> as a <c>.reloc</c> section.
    /// </summary>
    /// <returns>The buffer backing the new section.</returns>
    private static RelocationsDirectoryBuffer RebuildBaseRelocations(IPEImage image, PEFile file)
    {
        var buffer = new RelocationsDirectoryBuffer();
        foreach (var reloc in image.Relocations)
            buffer.Add(reloc);

        file.Sections.Add(new PESection(
            ".reloc",
            SectionFlags.MemoryRead | SectionFlags.ContentInitializedData,
            buffer));

        return buffer;
    }

    /// <summary>
    /// Collects every imported module of <paramref name="image"/> into a new import directory and
    /// appends it, followed by its import address table, to <paramref name="file"/> as an
    /// <c>.idata2</c> section.
    /// </summary>
    /// <returns>The buffer backing the new import directory.</returns>
    private static ImportDirectoryBuffer RebuildImportDirectories(IPEImage image, PEFile file)
    {
        var buffer = new ImportDirectoryBuffer(image.PEKind == OptionalHeaderMagic.PE32);
        foreach (var module in image.Imports)
            buffer.AddModule(module);

        file.Sections.Add(new PESection(
            ".idata2",
            SectionFlags.MemoryRead | SectionFlags.ContentInitializedData,
            new SegmentBuilder
            {
                {buffer, 8},
                {buffer.ImportAddressDirectory, 8},
            }));

        return buffer;
    }
}
// Global using directives
global using AsmResolver;
global using AsmResolver.DotNet;
global using AsmResolver.DotNet.Signatures;
global using AsmResolver.PE;
global using AsmResolver.PE.Code;
global using AsmResolver.PE.DotNet;
global using AsmResolver.PE.DotNet.Builder;
global using AsmResolver.PE.DotNet.Cil;
global using AsmResolver.PE.File;
global using AsmResolver.PE.File.Headers;
global using AsmResolver.PE.Imports;
global using AsmResolver.PE.Imports.Builder;
global using AsmResolver.PE.Platforms;
global using AsmResolver.PE.Relocations;
global using AsmResolver.PE.Relocations.Builder;
global using AsmResolver.PE.Tls;
// Patches EntryPoints.exe and ClassLibrary.dll and writes the results into a
// "patched" directory next to the inputs.
string exePath = @"Z:\input\EntryPoints.exe";
string exeOutputPath = Path.Combine(Path.GetDirectoryName(exePath)!, "patched", Path.GetFileName(exePath));
string dllPath = @"Z:\input\ClassLibrary.dll";
string dllOutputPath = Path.Combine(Path.GetDirectoryName(dllPath)!, "patched", Path.GetFileName(dllPath));

// Make sure the output directories exist before anything is written to them;
// CreateDirectory is a no-op when the directory is already there.
Directory.CreateDirectory(Path.GetDirectoryName(exeOutputPath)!);
Directory.CreateDirectory(Path.GetDirectoryName(dllOutputPath)!);

var exeModule = ModuleDefinition.FromFile(exePath);
var dllModule = ModuleDefinition.FromFile(dllPath);

// The executable gets its module constructor message here rather than via the
// ModuleCctor flag, so that the call below can be appended to the same constructor.
EntryPointsInjector.InjectModuleCctor(exeModule.Name!, exeModule);

// Inject reference to MyClass::Test() in <Module>::.cctor, right before its last instruction.
var instructions = exeModule.GetModuleConstructor()!.CilMethodBody!.Instructions;
instructions.InsertRange(instructions.Count - 1, new[]
{
    new CilInstruction(CilOpCodes.Call, dllModule
        .TopLevelTypes.First(t => t.Name == "MyClass")
        .Methods.First(m => m.Name == "Test")
        .ImportWith(exeModule.DefaultImporter))
});

var exeFile = EntryPointsInjector.InjectEntryPoints(exeModule.Name!, exeModule,
    EntryPoints.TlsCallback | EntryPoints.ExternalDll);
exeFile.Write(exeOutputPath);

var dllFile = EntryPointsInjector.InjectEntryPoints(dllModule.Name!, dllModule,
    EntryPoints.ModuleCctor
    | EntryPoints.NativeClrEntryPoint
    | EntryPoints.NativePEEntryPoint
    | EntryPoints.TlsCallback);
dllFile.Write(dllOutputPath);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment