From 88618e2a4ae398975a528d3a4ad94be63a630cac Mon Sep 17 00:00:00 2001 From: Joachim Bauch Date: Tue, 1 Feb 2005 23:04:37 +0000 Subject: [PATCH] started writing some PEB compatible code --- MemoryModule.c | 404 +++++++++++++++++------------ MemoryModule.h | 11 +- example/DllLoader/DllLoader.cpp | 81 +++++- example/DllLoader/DllLoader.vcproj | 7 +- example/DllMemory.sln | 8 + example/SampleDLL/SampleDLL.cpp | 12 + 6 files changed, 351 insertions(+), 172 deletions(-) diff --git a/MemoryModule.c b/MemoryModule.c index 10a6c30..56ac062 100644 --- a/MemoryModule.c +++ b/MemoryModule.c @@ -26,23 +26,24 @@ #include #include +#include "ntinternals.h" + +#define DEBUG_OUTPUT 0 + #ifdef DEBUG_OUTPUT #include #endif #include "MemoryModule.h" -typedef struct { - PIMAGE_NT_HEADERS headers; - unsigned char *codeBase; - HMODULE *modules; - int numModules; - int initialized; -} MEMORYMODULE, *PMEMORYMODULE; - typedef BOOL (WINAPI *DllEntryProc)(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpReserved); -#define GET_HEADER_DICTIONARY(module, idx) &(module)->headers->OptionalHeader.DataDirectory[idx] +#define GET_NT_HEADER(module) ((PIMAGE_NT_HEADERS)&((const unsigned char *)(module))[((PIMAGE_DOS_HEADER)(module))->e_lfanew]) +#define GET_HEADER_DICTIONARY(module, idx) &GET_NT_HEADER(module)->OptionalHeader.DataDirectory[idx] +#define CALCULATE_REAL_ADDRESS(base, offset) (((unsigned char *)(base)) + (offset)) + +// stores number of modules loaded +static DWORD ModuleCount = 0; #ifdef DEBUG_OUTPUT static void @@ -61,13 +62,14 @@ OutputLastError(const char *msg) #endif static void -CopySections(const unsigned char *data, PIMAGE_NT_HEADERS old_headers, PMEMORYMODULE module) +CopySections(const unsigned char *data, PIMAGE_NT_HEADERS old_headers, HMODULE module) { - int i, size; - unsigned char *codeBase = module->codeBase; - unsigned char *dest; - PIMAGE_SECTION_HEADER section = IMAGE_FIRST_SECTION(module->headers); - for (i=0; iheaders->FileHeader.NumberOfSections; i++, section++) + DWORD i, size; + LPVOID dest; + PIMAGE_SECTION_HEADER section; + + section = IMAGE_FIRST_SECTION(GET_NT_HEADER(module)); + for (i=0; iFileHeader.NumberOfSections; i++, section++) { if (section->SizeOfRawData == 0) { @@ -76,7 +78,7 @@ CopySections(const unsigned char *data, PIMAGE_NT_HEADERS old_headers, PMEMORYMO size = old_headers->OptionalHeader.SectionAlignment; if (size > 0) { - dest = (unsigned char *)VirtualAlloc(codeBase + section->VirtualAddress, + dest = VirtualAlloc(CALCULATE_REAL_ADDRESS(module, section->VirtualAddress), size, MEM_COMMIT, PAGE_READWRITE); @@ -90,7 +92,7 @@ CopySections(const unsigned char *data, PIMAGE_NT_HEADERS old_headers, PMEMORYMO } // commit memory block and copy data from dll - dest = (unsigned char *)VirtualAlloc(codeBase + section->VirtualAddress, + dest = VirtualAlloc(CALCULATE_REAL_ADDRESS(module, section->VirtualAddress), section->SizeOfRawData, MEM_COMMIT, PAGE_READWRITE); @@ -113,13 +115,14 @@ static int ProtectionFlags[2][2][2] = { }; static void -FinalizeSections(PMEMORYMODULE module) +FinalizeSections(HMODULE module) { - int i; - PIMAGE_SECTION_HEADER section = IMAGE_FIRST_SECTION(module->headers); + DWORD i; + PIMAGE_SECTION_HEADER section; // loop through all sections and change access flags - for (i=0; iheaders->FileHeader.NumberOfSections; i++, section++) + section = IMAGE_FIRST_SECTION(GET_NT_HEADER(module)); + for (i=0; iFileHeader.NumberOfSections; i++, section++) { DWORD protect, oldProtect, size; int executable = (section->Characteristics & IMAGE_SCN_MEM_EXECUTE) != 0; @@ -143,9 +146,9 @@ FinalizeSections(PMEMORYMODULE module) if (size == 0) { if (section->Characteristics & IMAGE_SCN_CNT_INITIALIZED_DATA) - size = module->headers->OptionalHeader.SizeOfInitializedData; + size = GET_NT_HEADER(module)->OptionalHeader.SizeOfInitializedData; else if (section->Characteristics & IMAGE_SCN_CNT_UNINITIALIZED_DATA) - size = module->headers->OptionalHeader.SizeOfUninitializedData; + size = GET_NT_HEADER(module)->OptionalHeader.SizeOfUninitializedData; } if (size > 0) @@ -161,19 +164,19 @@ FinalizeSections(PMEMORYMODULE module) } static void -PerformBaseRelocation(PMEMORYMODULE module, DWORD delta) +PerformBaseRelocation(HMODULE module, DWORD delta) { DWORD i; - unsigned char *codeBase = module->codeBase; - - PIMAGE_DATA_DIRECTORY directory = GET_HEADER_DICTIONARY(module, IMAGE_DIRECTORY_ENTRY_BASERELOC); + PIMAGE_DATA_DIRECTORY directory; + + directory = GET_HEADER_DICTIONARY(module, IMAGE_DIRECTORY_ENTRY_BASERELOC); if (directory->Size > 0) { - PIMAGE_BASE_RELOCATION relocation = (PIMAGE_BASE_RELOCATION)(codeBase + directory->VirtualAddress); + PIMAGE_BASE_RELOCATION relocation = (PIMAGE_BASE_RELOCATION)CALCULATE_REAL_ADDRESS(module, directory->VirtualAddress); for (; relocation->VirtualAddress > 0; ) { - unsigned char *dest = (unsigned char *)(codeBase + relocation->VirtualAddress); - unsigned short *relInfo = (unsigned short *)((unsigned char *)relocation + IMAGE_SIZEOF_BASE_RELOCATION); + unsigned char *dest = (unsigned char *)CALCULATE_REAL_ADDRESS(module, relocation->VirtualAddress); + unsigned short *relInfo = (unsigned short *)CALCULATE_REAL_ADDRESS(relocation, IMAGE_SIZEOF_BASE_RELOCATION); for (i=0; i<((relocation->SizeOfBlock-IMAGE_SIZEOF_BASE_RELOCATION) / 2); i++, relInfo++) { DWORD *patchAddrHL; @@ -209,21 +212,21 @@ PerformBaseRelocation(PMEMORYMODULE module, DWORD delta) } static int -BuildImportTable(PMEMORYMODULE module) +BuildImportTable(HMODULE module) { int result=1; - unsigned char *codeBase = module->codeBase; PIMAGE_DATA_DIRECTORY directory = GET_HEADER_DICTIONARY(module, IMAGE_DIRECTORY_ENTRY_IMPORT); if (directory->Size > 0) { - PIMAGE_IMPORT_DESCRIPTOR importDesc = (PIMAGE_IMPORT_DESCRIPTOR)(codeBase + directory->VirtualAddress); + PIMAGE_IMPORT_DESCRIPTOR importDesc = (PIMAGE_IMPORT_DESCRIPTOR)CALCULATE_REAL_ADDRESS(module, directory->VirtualAddress); for (; !IsBadReadPtr(importDesc, sizeof(IMAGE_IMPORT_DESCRIPTOR)) && importDesc->Name; importDesc++) { DWORD *thunkRef, *funcRef; - HMODULE handle = LoadLibrary((LPCSTR)(codeBase + importDesc->Name)); + HMODULE handle = LoadLibrary((LPCSTR)CALCULATE_REAL_ADDRESS(module, importDesc->Name)); if (handle == INVALID_HANDLE_VALUE) { + SetLastError(ERROR_MOD_NOT_FOUND); #if DEBUG_OUTPUT OutputLastError("Can't load library"); #endif @@ -231,33 +234,26 @@ BuildImportTable(PMEMORYMODULE module) break; } - module->modules = (HMODULE *)realloc(module->modules, (module->numModules+1)*(sizeof(HMODULE))); - if (module->modules == NULL) - { - result = 0; - break; - } - - module->modules[module->numModules++] = handle; if (importDesc->OriginalFirstThunk) { - thunkRef = (DWORD *)(codeBase + importDesc->OriginalFirstThunk); - funcRef = (DWORD *)(codeBase + importDesc->FirstThunk); + thunkRef = (DWORD *)CALCULATE_REAL_ADDRESS(module, importDesc->OriginalFirstThunk); + funcRef = (DWORD *)CALCULATE_REAL_ADDRESS(module, importDesc->FirstThunk); } else { // no hint table - thunkRef = (DWORD *)(codeBase + importDesc->FirstThunk); - funcRef = (DWORD *)(codeBase + importDesc->FirstThunk); + thunkRef = (DWORD *)CALCULATE_REAL_ADDRESS(module, importDesc->FirstThunk); + funcRef = (DWORD *)CALCULATE_REAL_ADDRESS(module, importDesc->FirstThunk); } for (; *thunkRef; thunkRef++, funcRef++) { if IMAGE_SNAP_BY_ORDINAL(*thunkRef) *funcRef = (DWORD)GetProcAddress(handle, (LPCSTR)IMAGE_ORDINAL(*thunkRef)); else { - PIMAGE_IMPORT_BY_NAME thunkData = (PIMAGE_IMPORT_BY_NAME)(codeBase + *thunkRef); + PIMAGE_IMPORT_BY_NAME thunkData = (PIMAGE_IMPORT_BY_NAME)CALCULATE_REAL_ADDRESS(module, *thunkRef); *funcRef = (DWORD)GetProcAddress(handle, (LPCSTR)&thunkData->Name); } if (*funcRef == 0) { + SetLastError(ERROR_PROC_NOT_FOUND); result = 0; break; } @@ -271,15 +267,163 @@ BuildImportTable(PMEMORYMODULE module) return result; } -HMEMORYMODULE MemoryLoadLibrary(const void *data) +static PPEB +GetPEB(void) { - PMEMORYMODULE result; + // XXX: is there a better or even documented way to do this? + __asm { + // get PEB + mov eax, dword ptr fs:[30h] + } +} + +static HMODULE +FindLibraryInPEB(const unsigned char *name, int incLoadCount) +{ + PPEB_LDR_DATA loaderData; + PLDR_MODULE loaderModule; + PWSTR longName; + size_t i; + HMODULE result=NULL; + + if (name == NULL) + return NULL; + + // convert name to long character name + longName = (PWSTR)calloc((strlen(name)+1)*2, 1); + for (i=0; iLoaderData; + loaderModule = (PLDR_MODULE)(loaderData->InLoadOrderModuleList.Flink); + while (1) + { + if (wcsicmp(longName, loaderModule->BaseDllName.Buffer) == 0) + { + result = loaderModule->BaseAddress; + if (incLoadCount && loaderModule->LoadCount != 0xffff) + // we use this module, so increate the load count + loaderModule->LoadCount++; + + goto exit; + } + + // advance to next module + loaderModule = (PLDR_MODULE)(loaderModule->InLoadOrderModuleList.Flink); + if (loaderModule->BaseAddress == NULL || loaderModule == (PLDR_MODULE)(loaderData->InLoadOrderModuleList.Flink)) + // we traversed through the complete list + // and didn't find the library + goto exit; + } + +exit: + free(longName); + + return result; +} + +// Append a loader module to the end of the loader data list of the PEB +#define AppendToChain(module, list, chain) { \ + (module)->##chain##.Flink = (list)->##chain##.Flink; \ + (module)->##chain##.Blink = (list)->##chain##.Blink; \ + ((PLDR_MODULE)((list)->##chain##.Blink))->##chain##.Flink = &(module)->##chain##; \ + (list)->##chain##.Blink = &(module)->##chain##; \ +}; + +static PLDR_MODULE +InsertModuleInPEB(HMODULE module, unsigned char *name, unsigned char *baseName, DWORD locationDelta) +{ + PLDR_MODULE loaderModule; + PPEB_LDR_DATA loaderData = GetPEB()->LoaderData; + DWORD entry = GET_NT_HEADER(module)->OptionalHeader.AddressOfEntryPoint; + size_t i; + + loaderModule = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, sizeof(LDR_MODULE)); + if (loaderModule == NULL) + return NULL; + + loaderModule->BaseAddress = module; + loaderModule->EntryPoint = (PVOID)(entry ? CALCULATE_REAL_ADDRESS(module, entry) : 0); + loaderModule->SizeOfImage = GET_NT_HEADER(module)->OptionalHeader.SizeOfImage; + loaderModule->LoadCount = 1; + + loaderModule->BaseDllName.Buffer = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, (strlen(baseName)+1)*2); + if (loaderModule->BaseDllName.Buffer == NULL) + { + HeapFree(GetProcessHeap(), 0, loaderModule); + return NULL; + } + loaderModule->BaseDllName.Length = (USHORT)strlen(baseName)*2; + loaderModule->BaseDllName.MaximumLength = (USHORT)HeapSize(GetProcessHeap(), 0, loaderModule->BaseDllName.Buffer); + for (i=0; iBaseDllName.Buffer[i] = baseName[i]; + + loaderModule->FullDllName.Buffer = HeapAlloc(GetProcessHeap(), HEAP_ZERO_MEMORY, (strlen(name)+1)*2); + if (loaderModule->BaseDllName.Buffer == NULL) + { + HeapFree(GetProcessHeap(), 0, loaderModule->BaseDllName.Buffer); + HeapFree(GetProcessHeap(), 0, loaderModule); + return NULL; + } + loaderModule->FullDllName.Length = (USHORT)strlen(name)*2; + loaderModule->FullDllName.MaximumLength = (USHORT)HeapSize(GetProcessHeap(), 0, loaderModule->FullDllName.Buffer); + for (i=0; iFullDllName.Buffer[i] = name[i]; + + // XXX: are these the correct flags? + loaderModule->Flags = IMAGE_DLL | ENTRY_PROCESSED; + if (locationDelta != 0) + loaderModule->Flags |= IMAGE_NOT_AT_BASE; + loaderModule->TimeDateStamp = GET_NT_HEADER(module)->FileHeader.TimeDateStamp; + + // XXX: do we need more set the hash table? + //loaderModule->HashTableEntry.Flink = &loaderModule->HashTableEntry; + //loaderModule->HashTableEntry.Blink = &loaderModule->HashTableEntry; + + AppendToChain(loaderModule, loaderData, InLoadOrderModuleList); + AppendToChain(loaderModule, loaderData, InInitializationOrderModuleList); + + // XXX: insert at the correct position in the chain + AppendToChain(loaderModule, loaderData, InMemoryOrderModuleList); + return loaderModule; +} + +HMODULE MemoryLoadLibrary(const void *data, unsigned char *name) +{ + HMODULE result; PIMAGE_DOS_HEADER dos_header; PIMAGE_NT_HEADERS old_header; - unsigned char *code, *headers; + LPVOID headers; + unsigned char *baseName; + unsigned char fullname[MAX_DLL_NAME_LENGTH], tempname[MAX_DLL_NAME_LENGTH]; DWORD locationDelta; DllEntryProc DllEntry; BOOL successfull; + DWORD hasFullName; + PLDR_MODULE loaderModule=NULL; + + // make sure we have a module name + if (name == NULL || strlen(name) == 0) + { + sprintf(tempname, "memorymodule%d", ModuleCount); + name = (unsigned char *)&tempname; + } + + // maybe a module with the given name has been loaded already + hasFullName = GetFullPathName(name, sizeof(fullname), (LPSTR)&fullname, &baseName); + + // search for module in PEB + result = FindLibraryInPEB(hasFullName ? baseName : name, 1); + if (result != NULL) + // already loaded this module + goto exit; + + if (hasFullName) + // use complete filename as module name + name = (unsigned char *)&fullname; + else + baseName = name; dos_header = (PIMAGE_DOS_HEADER)data; if (dos_header->e_magic != IMAGE_DOS_SIGNATURE) @@ -287,70 +431,61 @@ HMEMORYMODULE MemoryLoadLibrary(const void *data) #if DEBUG_OUTPUT OutputDebugString("Not a valid executable file.\n"); #endif - return NULL; + goto error; } - old_header = (PIMAGE_NT_HEADERS)&((const unsigned char *)(data))[dos_header->e_lfanew]; + old_header = GET_NT_HEADER(data); if (old_header->Signature != IMAGE_NT_SIGNATURE) { #if DEBUG_OUTPUT OutputDebugString("No PE header found.\n"); #endif - return NULL; + goto error; } // reserve memory for image of library - code = (unsigned char *)VirtualAlloc((LPVOID)(old_header->OptionalHeader.ImageBase), + result = (HMODULE)VirtualAlloc((LPVOID)(old_header->OptionalHeader.ImageBase), old_header->OptionalHeader.SizeOfImage, MEM_RESERVE, PAGE_READWRITE); - if (code == NULL) + if (result == NULL) // try to allocate memory at arbitrary position - code = (unsigned char *)VirtualAlloc(NULL, + result = (HMODULE)VirtualAlloc(NULL, old_header->OptionalHeader.SizeOfImage, MEM_RESERVE, PAGE_READWRITE); - if (code == NULL) + if (result == NULL) { -#if DEBUG_OUTPUT - OutputLastError("Can't reserve memory"); -#endif - return NULL; + SetLastError(ERROR_OUTOFMEMORY); + goto error; } - result = (PMEMORYMODULE)HeapAlloc(GetProcessHeap(), 0, sizeof(MEMORYMODULE)); - result->codeBase = code; - result->numModules = 0; - result->modules = NULL; - result->initialized = 0; - // XXX: is it correct to commit the complete memory region at once? // calling DllEntry raises an exception if we don't... - VirtualAlloc(code, + VirtualAlloc(result, old_header->OptionalHeader.SizeOfImage, MEM_COMMIT, PAGE_READWRITE); // commit memory for headers - headers = (unsigned char *)VirtualAlloc(code, - old_header->OptionalHeader.SizeOfHeaders, + headers = VirtualAlloc(result, + dos_header->e_lfanew + old_header->OptionalHeader.SizeOfHeaders, MEM_COMMIT, PAGE_READWRITE); // copy PE header to code memcpy(headers, dos_header, dos_header->e_lfanew + old_header->OptionalHeader.SizeOfHeaders); - result->headers = (PIMAGE_NT_HEADERS)&((const unsigned char *)(headers))[dos_header->e_lfanew]; - + // update position - result->headers->OptionalHeader.ImageBase = (DWORD)code; + GET_NT_HEADER(result)->OptionalHeader.ImageBase = (DWORD)result; // copy sections from DLL file block to new memory location CopySections(data, old_header, result); // adjust base address of imported data - locationDelta = (DWORD)(code - old_header->OptionalHeader.ImageBase); + locationDelta = (DWORD)((DWORD)result - old_header->OptionalHeader.ImageBase); if (locationDelta != 0) PerformBaseRelocation(result, locationDelta); @@ -362,106 +497,53 @@ HMEMORYMODULE MemoryLoadLibrary(const void *data) // sections that are marked as "discardable" FinalizeSections(result); + // Add loaded module to PEB + if (!(loaderModule = InsertModuleInPEB(result, name, baseName, locationDelta))) + goto error; + // get entry point of loaded library - if (result->headers->OptionalHeader.AddressOfEntryPoint != 0) + if (GET_NT_HEADER(result)->OptionalHeader.AddressOfEntryPoint != 0) { - DllEntry = (DllEntryProc)(code + result->headers->OptionalHeader.AddressOfEntryPoint); - if (DllEntry == 0) - { -#if DEBUG_OUTPUT - OutputDebugString("Library has no entry point.\n"); -#endif - goto error; - } - // notify library about attaching to process - successfull = (*DllEntry)((HINSTANCE)code, DLL_PROCESS_ATTACH, 0); + DllEntry = (DllEntryProc)CALCULATE_REAL_ADDRESS(result, GET_NT_HEADER(result)->OptionalHeader.AddressOfEntryPoint); + successfull = (*DllEntry)(result, DLL_PROCESS_ATTACH, 0); if (!successfull) { -#if DEBUG_OUTPUT - OutputDebugString("Can't attach library.\n"); -#endif + SetLastError(ERROR_DLL_INIT_FAILED); goto error; } - result->initialized = 1; + loaderModule->Flags |= PROCESS_ATTACH_CALLED; } - return (HMEMORYMODULE)result; + ModuleCount++; + goto exit; error: - // cleanup - MemoryFreeLibrary(result); - return NULL; -} - -FARPROC MemoryGetProcAddress(HMEMORYMODULE module, const char *name) -{ - unsigned char *codeBase = ((PMEMORYMODULE)module)->codeBase; - int idx=-1; - DWORD i, *nameRef; - WORD *ordinal; - PIMAGE_EXPORT_DIRECTORY exports; - PIMAGE_DATA_DIRECTORY directory = GET_HEADER_DICTIONARY((PMEMORYMODULE)module, IMAGE_DIRECTORY_ENTRY_EXPORT); - if (directory->Size == 0) - // no export table found - return NULL; - - exports = (PIMAGE_EXPORT_DIRECTORY)(codeBase + directory->VirtualAddress); - if (exports->NumberOfNames == 0 || exports->NumberOfFunctions == 0) - // DLL doesn't export anything - return NULL; - - // search function name in list of exported names - nameRef = (DWORD *)(codeBase + exports->AddressOfNames); - ordinal = (WORD *)(codeBase + exports->AddressOfNameOrdinals); - for (i=0; iNumberOfNames; i++, nameRef++, ordinal++) - if (stricmp(name, (const char *)(codeBase + *nameRef)) == 0) - { - idx = *ordinal; - break; - } - - if (idx == -1) - // exported symbol not found - return NULL; - - if ((DWORD)idx > exports->NumberOfFunctions) - // name <-> ordinal number don't match - return NULL; - - // AddressOfFunctions contains the RVAs to the "real" functions - return (FARPROC)(codeBase + *(DWORD *)(codeBase + exports->AddressOfFunctions + (idx*4))); -} - -void MemoryFreeLibrary(HMEMORYMODULE mod) -{ - int i; - PMEMORYMODULE module = (PMEMORYMODULE)mod; - - if (module != NULL) + // perform some cleanup... + if (loaderModule != NULL) { - if (module->initialized != 0) - { - // notify library about detaching from process - DllEntryProc DllEntry = (DllEntryProc)(module->codeBase + module->headers->OptionalHeader.AddressOfEntryPoint); - (*DllEntry)((HINSTANCE)module->codeBase, DLL_PROCESS_DETACH, 0); - module->initialized = 0; - } - - if (module->modules != NULL) - { - // free previously opened libraries - for (i=0; inumModules; i++) - if (module->modules[i] != INVALID_HANDLE_VALUE) - FreeLibrary(module->modules[i]); + if ((loaderModule->Flags & PROCESS_ATTACH_CALLED) != 0) + (*DllEntry)(result, DLL_PROCESS_DETACH, 0); + + // remove from module chains + loaderModule->InInitializationOrderModuleList.Flink->Blink = loaderModule->InInitializationOrderModuleList.Blink; + loaderModule->InInitializationOrderModuleList.Blink->Flink = loaderModule->InInitializationOrderModuleList.Flink; + loaderModule->InLoadOrderModuleList.Flink->Blink = loaderModule->InLoadOrderModuleList.Blink; + loaderModule->InLoadOrderModuleList.Blink->Flink = loaderModule->InLoadOrderModuleList.Flink; + loaderModule->InMemoryOrderModuleList.Flink->Blink = loaderModule->InMemoryOrderModuleList.Blink; + loaderModule->InMemoryOrderModuleList.Blink->Flink = loaderModule->InMemoryOrderModuleList.Flink; + + // free memory for PEB structures + HeapFree(GetProcessHeap(), 0, loaderModule->BaseDllName.Buffer); + HeapFree(GetProcessHeap(), 0, loaderModule->FullDllName.Buffer); + HeapFree(GetProcessHeap(), 0, loaderModule); + } - free(module->modules); - } + if (result != NULL) + VirtualFree(result, 0, MEM_RELEASE); - if (module->codeBase != NULL) - // release memory of library - VirtualFree(module->codeBase, 0, MEM_RELEASE); + result = NULL; - HeapFree(GetProcessHeap(), 0, module); - } +exit: + return result; } diff --git a/MemoryModule.h b/MemoryModule.h index 15a7b8e..9dd4c31 100644 --- a/MemoryModule.h +++ b/MemoryModule.h @@ -26,17 +26,16 @@ #include -typedef void *HMEMORYMODULE; - #ifdef __cplusplus extern "C" { #endif -HMEMORYMODULE MemoryLoadLibrary(const void *); - -FARPROC MemoryGetProcAddress(HMEMORYMODULE, const char *); +HMODULE MemoryLoadLibrary(const void *, unsigned char *); -void MemoryFreeLibrary(HMEMORYMODULE); +// backwards compatibility +#define HMEMORYMODULE HMODULE +#define MemoryGetProcAddress GetProcAddress +#define MemoryFreeLibrary FreeLibrary #ifdef __cplusplus } diff --git a/example/DllLoader/DllLoader.cpp b/example/DllLoader/DllLoader.cpp index 4fa8ca3..743434d 100644 --- a/example/DllLoader/DllLoader.cpp +++ b/example/DllLoader/DllLoader.cpp @@ -4,22 +4,92 @@ #include #include +#include "../../ntinternals.h" #include "../../MemoryModule.h" typedef int (*addNumberProc)(int, int); #define DLL_FILE "..\\..\\SampleDLL\\Debug\\SampleDLL.dll" +PPEB GetPEB(void) +{ + // XXX: is there a better or even documented way to do this? + __asm { + // get PEB + mov eax, dword ptr fs:[30h] + } +} + +void DumpPEB(void) +{ + PPEB peb = GetPEB(); + PPEB_LDR_DATA loaderData = peb->LoaderData; + PLDR_MODULE loaderModule; + + printf("-------------------------------------\n"); + printf("PEB at 0x%x\n", (DWORD)peb); + printf("Modules (Load Order)\n"); + loaderModule = (PLDR_MODULE)(loaderData->InLoadOrderModuleList.Flink); + printf("Last: %x\n", (DWORD)loaderData->InLoadOrderModuleList.Blink); + while (1) + { + printf("Info: %x\n", (DWORD)loaderModule); + if (!IsBadReadPtr(loaderModule->BaseDllName.Buffer, loaderModule->BaseDllName.Length)) + { + wprintf(L"Library: %s\n", loaderModule->BaseDllName.Buffer); + wprintf(L"Fullname: %s\n", loaderModule->FullDllName.Buffer); + } else + printf("Unknown module\n"); + printf("Address: %x\n", (DWORD)loaderModule->BaseAddress); + printf("Load count: %d\n", loaderModule->LoadCount); + printf("Flags: %d\n", loaderModule->Flags); + printf("Size: %d\n", loaderModule->SizeOfImage); + printf("Entry: %x\n", (DWORD)loaderModule->EntryPoint); + printf("Hash: %x %x\n", (DWORD)loaderModule->HashTableEntry.Flink, (DWORD)loaderModule->HashTableEntry.Blink); + + // advance to next module + loaderModule = (PLDR_MODULE)(loaderModule->InLoadOrderModuleList.Flink); + if (loaderModule == (PLDR_MODULE)(loaderData->InLoadOrderModuleList.Flink)) + // we traversed through the complete list + // and didn't find the library + goto exit; + + printf("\n"); + } + +exit: + printf("=====================================\n"); + printf("\n"); +} + +static void +OutputLastError(const char *msg) +{ + LPVOID tmp; + char *tmpmsg; + FormatMessage(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, GetLastError(), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPTSTR)&tmp, 0, NULL); + tmpmsg = (char *)malloc(strlen(msg) + strlen((const char *)tmp) + 3); + sprintf(tmpmsg, "%s: %s", msg, tmp); + OutputDebugString(tmpmsg); + free(tmpmsg); + LocalFree(tmp); +} + void LoadFromFile(void) { addNumberProc addNumber; - HINSTANCE handle = LoadLibrary(DLL_FILE); + HINSTANCE handle; + DumpPEB(); + handle = LoadLibrary(DLL_FILE); if (handle == INVALID_HANDLE_VALUE) return; + DumpPEB(); addNumber = (addNumberProc)GetProcAddress(handle, "addNumbers"); printf("From file: %d\n", addNumber(1, 2)); FreeLibrary(handle); + DumpPEB(); } void LoadFromMemory(void) @@ -44,15 +114,20 @@ void LoadFromMemory(void) fread(data, 1, size, fp); fclose(fp); - module = MemoryLoadLibrary(data); + DumpPEB(); + module = MemoryLoadLibrary(data, (unsigned char *)DLL_FILE); if (module == NULL) { printf("Can't load library from memory.\n"); goto exit; } + DumpPEB(); addNumber = (addNumberProc)MemoryGetProcAddress(module, "addNumbers"); - printf("From memory: %d\n", addNumber(1, 2)); + if (addNumber) + printf("From memory: %d\n", addNumber(1, 2)); + else + printf("Not found\n"); MemoryFreeLibrary(module); exit: diff --git a/example/DllLoader/DllLoader.vcproj b/example/DllLoader/DllLoader.vcproj index e6e24dd..ddde550 100644 --- a/example/DllLoader/DllLoader.vcproj +++ b/example/DllLoader/DllLoader.vcproj @@ -19,7 +19,7 @@ + + +#include + #include "SampleDLL.h" extern "C" { +BOOL WINAPI DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpvReserved) +{ + printf("DllMain called:\n"); + printf("Instance: %x\n", (DWORD)hinstDLL); + printf("Reason: %d\n", fdwReason); + printf("Reserved: %x\n", (DWORD)lpvReserved); + return 1; +} + SAMPLEDLL_API int addNumbers(int a, int b) { return a + b;