Created
January 13, 2017 05:48
-
-
Save maxmcguire/b7ed7954d6271bac2b49ddea2fc3d87a to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
void CorrectFileCase(const wchar_t* srcFileName, wchar_t* dstFileName, int maxLength) | |
{ | |
struct FILE_NAME_INFORMATION | |
{ | |
ULONG FileNameLength; | |
WCHAR FileName[1024 + 1]; | |
}; | |
typedef NTSTATUS (NTAPI *_NtQueryInformationFile)(HANDLE, PIO_STATUS_BLOCK, PVOID, ULONG, FILE_INFORMATION_CLASS); | |
bool useFallback = false; | |
DWORD flagsAndAttributes = FILE_FLAG_BACKUP_SEMANTICS; // Allows us to open directories. | |
HANDLE hFile = CreateFile(srcFileName, GENERIC_READ, FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, nullptr, OPEN_EXISTING, flagsAndAttributes, nullptr); | |
bool result = false; | |
if (hFile != INVALID_HANDLE_VALUE) | |
{ | |
static _NtQueryInformationFile NtQueryInformationFile = nullptr; | |
if (NtQueryInformationFile == nullptr) | |
{ | |
HMODULE hDll = LoadLibraryW(L"ntdll.dll"); | |
NtQueryInformationFile = (_NtQueryInformationFile)GetProcAddress(hDll, "NtQueryInformationFile"); | |
if (NtQueryInformationFile == nullptr) | |
{ | |
Log_Error(2, "Couldn't get NtQueryInformationFile function"); | |
} | |
} | |
IO_STATUS_BLOCK iosb; | |
FILE_NAME_INFORMATION nameInformation; | |
NTSTATUS status = NtQueryInformationFile(hFile, &iosb, &nameInformation, | |
sizeof(nameInformation), (FILE_INFORMATION_CLASS)9); // FileNameInformation | |
CloseHandle(hFile); | |
if (status == 0) | |
{ | |
nameInformation.FileName[nameInformation.FileNameLength / sizeof(WCHAR)] = 0; | |
// Fix up the slashes. | |
for (int i = 0; nameInformation.FileName[i] != 0; ++i) | |
{ | |
if (nameInformation.FileName[i] == L'\\') | |
{ | |
nameInformation.FileName[i] = L'/'; | |
} | |
} | |
dstFileName[0] = 0; | |
int length = 0; | |
// We don't get the volume label, so just use that from the original file name. | |
// There are ways of getting the proper volume label case, but it's more expensive and complex. | |
const WCHAR* volumeLabelEnd = wcschr(srcFileName, L':'); | |
if (volumeLabelEnd != nullptr) | |
{ | |
length = volumeLabelEnd - srcFileName + 1; | |
wcsncpy(dstFileName, srcFileName, volumeLabelEnd - srcFileName + 1); | |
} | |
int appendLength = Min(maxLength - length - 1, nameInformation.FileNameLength); | |
wcsncpy(dstFileName + length, nameInformation.FileName, appendLength); | |
dstFileName[appendLength] = 0; | |
int srcLength = wcslen(srcFileName); | |
int dstLength = wcslen(dstFileName); | |
if (srcLength > 0 && srcFileName[srcLength - 1] == L'/' || srcFileName[srcLength - 1] == L'\\') | |
{ | |
// Make sure we have a trailing slash. | |
if (dstLength > 0 && dstFileName[dstLength - 1] != '/') | |
{ | |
if (dstLength + 1 < maxLength) | |
{ | |
dstFileName[dstLength] = '/'; | |
dstFileName[dstLength + 1] = 0; | |
} | |
} | |
} | |
} | |
else | |
{ | |
Log_Message(2, "Using fallback method for testing file name case"); | |
useFallback = true; | |
} | |
} | |
else if (GetLastError() != ERROR_FILE_NOT_FOUND) | |
{ | |
useFallback = true; | |
} | |
if (useFallback) | |
{ | |
wchar_t shortFileName[1025]; | |
if (GetShortPathName(srcFileName, shortFileName, countof(shortFileName) - 1) != 0) | |
{ | |
GetLongPathName(shortFileName, dstFileName, maxLength - 1); | |
} | |
else | |
{ | |
wcsncpy(dstFileName, srcFileName, maxLength - 1); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment