Introduction
Sometimes it is necessary to know which module is calling your exported function. You might want to verify that the calling module is one you have certified, or your function might need to return a different result depending on the calling context, and so on.
Background
This article builds on Chavdar Dimitrov's article, whose approach works well, but only for native code. His example does not handle .NET callers, and his original DLLs made release builds crash, probably because of tightened Windows security. I therefore reworked the code to meet the goals above, albeit with reduced functionality.
Digging into the code
The function is written in C++ so that it can be called easily from both managed and unmanaged code. I've provided DLLs and EXEs covering the various situations that can occur. The .NET 2.0 runtime must be installed, because stackdumper.dll is a mixed-mode DLL in which some files are compiled with the /clr option. External code calls GetCallingModulePath, an entry point exported by stackdumper.dll. For reasons that are not entirely clear, it is important that this function takes an argument.
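const char* __stdcall GetCallingModulePath(int arg)
{
    // Capture this function's frame pointer (x86 only: MSVC does not
    // support inline assembly when targeting x64).
    long reg_ebp;
    __asm {
        mov eax, ebp
        mov reg_ebp, eax
    }
    ADDR callerAddr;
    unsigned i = 0;
    // Declared before the gotos: C++ does not allow jumping past
    // declarations that have initializers.
    long temp;
    SIZE_T cnt;
    long* p;
    BOOL bres;
    HANDLE h = GetCurrentProcess();
    module.clear();                       // reset the global result string
    // Return address stored in this frame, i.e. a code address in the caller.
    callerAddr = GetCallerAddr(reg_ebp);
    if (callerAddr == 0)
        goto last;
    // Resolve the address to a module path; if that module is the .NET
    // runtime, ask the managed helper for the real calling assembly.
    if (getFuncInfo(callerAddr, module) > 0)
    {
        BOOL bnet = IsDotNetRuntime((char *)module.c_str());
        if (bnet)
        {
            BOOL bres = GetDotNetCallerFileName(module);
            if (bres == TRUE)
                goto last;
        }
    }
    // Follow the saved EBP stored at [ebp] to move one frame up the stack.
    temp = 0;
    p = (long*)reg_ebp;
    bres = ReadProcessMemory(h, (LPCVOID)p, (LPVOID)&temp, sizeof(long), &cnt);
    reg_ebp = temp;
    i++;
last:
    .......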
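getFuncInfo appears to turn the caller's return address into a module path; its body is not shown. The underlying idea is a range lookup over the loaded modules. The portable sketch below illustrates that idea under stated assumptions: ModuleRange and FindModuleForAddress are hypothetical names, and the module table is supplied by hand rather than read from a live process. On Windows, GetModuleHandleEx with the GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS flag followed by GetModuleFileName performs this lookup directly.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-in for one entry of the loaded-module list.
struct ModuleRange {
    std::uintptr_t base;  // first address of the mapped image
    std::size_t size;     // size of the mapped image in bytes
    std::string path;     // full path of the image file
};

// Returns the path of the module whose [base, base + size) range contains
// addr, or an empty string if no module matches.
std::string FindModuleForAddress(const std::vector<ModuleRange>& modules,
                                 std::uintptr_t addr) {
    for (const ModuleRange& m : modules) {
        // addr >= base is checked first so the subtraction cannot wrap.
        if (addr >= m.base && addr - m.base < m.size) {
            return m.path;
        }
    }
    return std::string();
}
```

A return address just past the end of an image (base + size) is deliberately not matched, because the range is half-open.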
If you want to understand how stack tracing works in native code, you should start with Dimitrov's article mentioned above.
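As a rough illustration of the frame-pointer chain that GetCallingModulePath follows, the sketch below walks a simulated 32-bit stack. It assumes every function was compiled with a standard EBP frame (push ebp; mov ebp, esp), so [ebp] holds the caller's saved EBP and [ebp + 4] holds the return address. The std::map stands in for ReadProcessMemory, and StackMemory, ReadWord and WalkFrames are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <vector>

// Simulated 32-bit stack memory: address -> 32-bit word.
using StackMemory = std::map<std::uint32_t, std::uint32_t>;

// Reads one word, returning false if the address is not mapped
// (the role ReadProcessMemory plays against a real process).
bool ReadWord(const StackMemory& mem, std::uint32_t addr, std::uint32_t& out) {
    auto it = mem.find(addr);
    if (it == mem.end()) return false;
    out = it->second;
    return true;
}

// Walks EBP-based frames: [ebp] is the saved caller EBP, [ebp + 4] is the
// return address into the caller. Returns return addresses, innermost first.
std::vector<std::uint32_t> WalkFrames(const StackMemory& mem, std::uint32_t ebp,
                                      std::size_t maxFrames) {
    std::vector<std::uint32_t> returnAddrs;
    while (ebp != 0 && returnAddrs.size() < maxFrames) {
        std::uint32_t retAddr = 0, savedEbp = 0;
        if (!ReadWord(mem, ebp + 4, retAddr) || !ReadWord(mem, ebp, savedEbp))
            break;  // unreadable frame: stop instead of crashing
        returnAddrs.push_back(retAddr);
        // The stack grows down, so each caller's frame must sit higher.
        if (savedEbp <= ebp) break;
        ebp = savedEbp;
    }
    return returnAddrs;
}
```

With three frames at 0x1000, 0x1020 and 0x1040 chained together, WalkFrames returns their three return addresses, innermost first. The upward-only check matters in practice: frames compiled with frame-pointer omission break the EBP chain, and without the check a corrupt chain can send the walk into a loop.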
Points of interest
The GetDotNetCallerFileName managed function skips all the .NET runtime assemblies on the stack and returns the assembly that actually called your function.
int GetDotNetCallerFileName(string& module)
{
    int res = FALSE;
    try
    {
        // Assembly containing the method that called this one.
        Assembly^ callerAssembly = Assembly::GetCallingAssembly();
        if (callerAssembly == nullptr)
            return FALSE;
        // Directory the CLR was loaded from (where mscorlib.dll lives).
        String^ sysdir =
            System::Runtime::InteropServices::RuntimeEnvironment::GetRuntimeDirectory();
        String^ strCallerPath = callerAssembly->Location;
        String^ directoryName = nullptr;
        if (strCallerPath != nullptr)
            directoryName = Path::GetDirectoryName(strCallerPath) + "\\";
        // Skip assemblies in the runtime directory, and this DLL itself.
        while ((directoryName != nullptr
                && String::Compare(sysdir, directoryName, true) == 0)
               || callerAssembly == Assembly::GetExecutingAssembly())
        {
            // GetCallingAssembly is static: it always reports the caller of
            // this method, whichever instance it is invoked through.
            callerAssembly = Assembly::GetCallingAssembly();
            // Read the new location before deriving its directory.
            strCallerPath = callerAssembly->Location;
            directoryName = nullptr;
            if (strCallerPath != nullptr)
                directoryName = Path::GetDirectoryName(strCallerPath) + "\\";
        }
.........
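The filtering rule in that loop can be checked in isolation. The portable sketch below assumes the caller chain is already available as a list of assembly paths, innermost first (in managed code it would come from the runtime, for example via System.Diagnostics.StackTrace), and applies the same two tests: skip assemblies whose directory matches the runtime directory case-insensitively, and skip the DLL itself. EqualsIgnoreCase, DirectoryWithSlash and FirstNonRuntimeCaller are hypothetical helpers.

```cpp
#include <cassert>
#include <cctype>
#include <cstddef>
#include <string>
#include <vector>

// Case-insensitive ASCII comparison, standing in for String::Compare(a, b, true) == 0.
bool EqualsIgnoreCase(const std::string& a, const std::string& b) {
    if (a.size() != b.size()) return false;
    for (std::size_t i = 0; i < a.size(); ++i) {
        if (std::tolower(static_cast<unsigned char>(a[i])) !=
            std::tolower(static_cast<unsigned char>(b[i])))
            return false;
    }
    return true;
}

// Directory part of a Windows path, keeping the trailing backslash,
// like Path::GetDirectoryName(path) + "\\" in the managed code.
std::string DirectoryWithSlash(const std::string& path) {
    std::size_t pos = path.find_last_of('\\');
    return pos == std::string::npos ? std::string() : path.substr(0, pos + 1);
}

// callerPaths: assembly locations ordered from the immediate caller outward.
// Returns the first path that is neither in runtimeDir nor selfPath itself,
// or an empty string if every entry was skipped.
std::string FirstNonRuntimeCaller(const std::vector<std::string>& callerPaths,
                                  const std::string& runtimeDir,
                                  const std::string& selfPath) {
    for (const std::string& path : callerPaths) {
        bool inRuntimeDir = EqualsIgnoreCase(DirectoryWithSlash(path), runtimeDir);
        bool isSelf = EqualsIgnoreCase(path, selfPath);
        if (!inRuntimeDir && !isSelf) return path;
    }
    return std::string();
}
```

For a chain of stackdumper.dll, mscorlib.dll (in the framework directory) and Client.exe, the first two entries are skipped and Client.exe is reported.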
Calling the code
For native callers, both direct and indirect calls are straightforward:
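#include <stdio.h>
#include <objbase.h>   // CoTaskMemFree

// Direct call: the caller is the executable itself.
const char *szc = GetCallingModulePath(1);
printf("expected result: test.exe\nfinal result = %s\n\n", szc);
// Indirect call: NativeDllCall in Nativedll.dll calls GetCallingModulePath,
// so Nativedll.dll is reported as the caller.
szc = NativeDllCall(1);
printf("expected result: Nativedll.dll\nfinal result = %s\n", szc);
// Full module stack trace: the DLL allocates the array and every string,
// and the caller releases them with CoTaskMemFree.
char** pszc = NULL, **original = NULL;
int cnt = GetModuleStackTraceFromNative(pszc);
original = pszc;           // keep the array start; pszc is advanced below
printf("\nstack trace---------\n");
while (cnt-- > 0)
{
    printf("Trace: = %s\n", *pszc);
    CoTaskMemFree(*pszc);  // free each string
    pszc++;
}
CoTaskMemFree(original);   // then free the array itself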
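The cleanup loop above follows a caller-frees contract: the DLL allocates both the array and each string, and the caller releases every string before the array. The sketch below shows both sides of that contract in portable C++, with std::malloc and std::free standing in for CoTaskMemAlloc and CoTaskMemFree (the DLL presumably allocates with CoTaskMemAlloc so that CoTaskMemFree is the matching release). MakeTrace and ConsumeTrace are hypothetical. The sketch takes an explicit char*** because GetModuleStackTraceFromNative(pszc) can only fill pszc if the real function takes it by reference; its exact signature is not shown.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <string>
#include <vector>

// Producer side: allocates an array of C strings plus each string.
// Returns the number of entries; *out receives the array (nullptr on failure).
int MakeTrace(const std::vector<std::string>& names, char*** out) {
    *out = nullptr;
    if (names.empty()) return 0;
    char** arr = static_cast<char**>(std::malloc(names.size() * sizeof(char*)));
    if (arr == nullptr) return 0;
    for (std::size_t i = 0; i < names.size(); ++i) {
        arr[i] = static_cast<char*>(std::malloc(names[i].size() + 1));
        if (arr[i] == nullptr) {  // roll back everything allocated so far
            for (std::size_t j = 0; j < i; ++j) std::free(arr[j]);
            std::free(arr);
            return 0;
        }
        std::memcpy(arr[i], names[i].c_str(), names[i].size() + 1);
    }
    *out = arr;
    return static_cast<int>(names.size());
}

// Consumer side: copies the strings out, frees each string, then the array,
// in the same order as the cleanup loop in the sample.
std::vector<std::string> ConsumeTrace(char** arr, int count) {
    std::vector<std::string> result;
    for (int i = 0; i < count; ++i) {
        result.emplace_back(arr[i]);
        std::free(arr[i]);
    }
    std::free(arr);  // freeing nullptr is a no-op
    return result;
}
```

Keeping a copy of the array start (original in the sample) is what makes the final free possible, since the loop advances pszc past the end.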
There is no need to call CoTaskMemFree after GetCallingModulePath, because the returned pointer refers to a global variable inside stackdumper.dll. That global is cleared on every call, so copy the string if you need to keep it. When calling from managed code (such as C#), declare the function as an external import:
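[DllImport("stackdumper.dll", CharSet = CharSet.Ansi)]
static extern IntPtr GetCallingModulePath(int arg);
and then invoke it as a static method, converting the result yourself:
string caller = Marshal.PtrToStringAnsi(GetCallingModulePath(1));
The return type is declared as IntPtr rather than string because, for a string return value, the interop marshaler frees the native buffer with CoTaskMemFree, which must not happen to a pointer into the DLL's global variable.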
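The ownership rule for GetCallingModulePath, a pointer into a global owned by the DLL, can be illustrated with a static std::string. In the hypothetical GetLastResult below, callers must not free the returned pointer, and the next call overwrites the buffer, so a caller that needs the value later should copy it first. Because every caller shares the same buffer, the pattern is also not thread-safe.

```cpp
#include <cassert>
#include <string>

// Hypothetical function with the same ownership model as GetCallingModulePath:
// the result lives in storage owned by the callee and is reused on every call.
const char* GetLastResult(const std::string& newValue) {
    static std::string result;  // one shared buffer for all callers
    result = newValue;          // overwrites whatever the previous call returned
    return result.c_str();      // valid only until the next call
}
```

A caller that stores std::string copy = GetLastResult(...) keeps a stable value; holding on to the raw pointer across another call does not.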