
Get the calling module that invoked your function

29 Jan 2007, CPOL, 2 min read
How to get the calling module that invoked your function as an entry point.

Sample Image - getcallingmodule.png

Introduction

Sometimes it is necessary to know which module is calling your exported function. You might want to verify that the calling module has been certified by you, or your function might need to return a different result depending on the calling context.

Background

This article builds on Chavdar Dimitrov's article, whose approach works well but only for native code. His example does not handle callers running under .NET, and with his original DLLs the release build crashes, probably because of tightened Windows security. I therefore polished the code so that it meets both goals, albeit with reduced functionality.

Digging into the code

The function is written in C++ so that it can easily be called from both managed and unmanaged code. I've provided DLLs and EXEs that cover the various calling scenarios. The .NET 2.0 runtime must be installed, because stackdumper.dll is a mixed-mode DLL in which some source files are compiled with the /clr option. External code calls GetCallingModulePath, an exported entry point of stackdumper.dll. For reasons I have not fully pinned down, it is important that this function takes an argument.
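Conceptually, the native side has to turn a caller's return address into the path of the module whose image contains it; in the listing below, the getFuncInfo helper appears to do this, but its implementation is not shown. The following is a minimal portable sketch of that mapping only, assuming the loaded modules and their address ranges have already been collected (on Windows, for example, from the process's module list). ModuleRange and FindModule are hypothetical names, not part of stackdumper.dll.

```cpp
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical description of one loaded module: its image occupies
// the half-open range [base, base + size) in the process address space.
struct ModuleRange {
    std::uintptr_t base;
    std::size_t size;
    std::string path;
};

// Return the path of the module whose image contains 'addr',
// or an empty string if the address lies outside every known module.
std::string FindModule(const std::vector<ModuleRange>& modules, std::uintptr_t addr) {
    for (const ModuleRange& m : modules) {
        // Subtracting first avoids overflow in base + size near the top of memory.
        if (addr >= m.base && addr - m.base < m.size)
            return m.path;
    }
    return std::string();
}
```

A linear scan is enough for a handful of modules; with many modules, sorting the ranges by base address and using a binary search would scale better.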

MC++
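// Exported entry point. Returns a pointer into the global 'module' string,
// so callers must not free it. x86 only (MSVC inline assembly).
const char* __stdcall GetCallingModulePath(int arg)
{
    // Read the current frame pointer: [ebp] holds the caller's saved EBP,
    // [ebp+4] the return address into the caller.
    long reg_ebp;
    __asm{
        mov eax, ebp
        mov reg_ebp, eax
    }
    ADDR callerAddr;
    unsigned i = 0;
    long temp = 0;      // declared before the gotos so no initialization is skipped
    SIZE_T cnt = 0;
    HANDLE h = GetCurrentProcess();

    module.clear();     // reset the global result (empty() would only test it)
    callerAddr = GetCallerAddr(reg_ebp);
    if (callerAddr == 0)
        goto last;

    // Resolve the address to a module; if it is the CLR, find the managed caller.
    if(getFuncInfo(callerAddr, module) > 0)
    {
        BOOL bnet = IsDotNetRuntime((char *)module.c_str());
        if(bnet)
        {
            BOOL bres = GetDotNetCallerFileName(module);
            if(bres == TRUE)
                goto last;
        }
    }
    // Step to the next frame by loading the saved EBP stored at [ebp].
    // If the read fails, temp stays 0 and the chain ends.
    ReadProcessMemory(h, (LPCVOID)reg_ebp, (LPVOID)&temp, sizeof(long), &cnt);
    reg_ebp = temp;
    i++;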

last:
.......

If you want to understand how stack tracing works in native code, you should start with Dimitrov's article mentioned above.
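The heart of the native walk is following the saved-EBP chain: in a standard x86 frame, [ebp] holds the caller's saved EBP and [ebp+4] holds the return address into the caller. The sketch below is a simulation, not the stackdumper.dll code: it models stack memory as a plain array of 32-bit words, with frame pointers as word indices, so the chain-following logic can run anywhere. SimulatedStack and WalkFrames are hypothetical names.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical model of a 32-bit x86 stack: 'words' stands in for memory,
// and a frame pointer is an index (in 32-bit words) into it.
// In a standard frame: words[ebp] = caller's saved EBP,
//                      words[ebp + 1] = return address ([ebp+4] in bytes).
struct SimulatedStack {
    std::vector<std::uint32_t> words;
};

// Collect return addresses by following the saved-EBP chain, innermost first.
// Stops at a zero frame pointer, an out-of-range read, a chain that does not
// move toward higher addresses, or after maxFrames frames.
std::vector<std::uint32_t> WalkFrames(const SimulatedStack& stack, std::uint32_t ebp,
                                      unsigned maxFrames = 64) {
    std::vector<std::uint32_t> returnAddrs;
    for (unsigned i = 0; i < maxFrames && ebp != 0; ++i) {
        if (ebp + 1 >= stack.words.size())
            break;                                    // would read past the stack
        returnAddrs.push_back(stack.words[ebp + 1]);  // [ebp+4]: return address
        std::uint32_t next = stack.words[ebp];        // [ebp]: caller's saved EBP
        // The x86 stack grows downward, so callers' frames sit at higher
        // addresses; a non-increasing pointer means the chain ended or is corrupt.
        if (next <= ebp)
            break;
        ebp = next;
    }
    return returnAddrs;
}
```

The maxFrames limit and the "must increase" check play the role of the safety checks a real walker needs, since a frame compiled without a frame pointer would break the chain.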

Points of interest

The managed function GetDotNetCallerFileName skips the .NET Framework assemblies on the stack (those located in the runtime directory) as well as stackdumper.dll itself, and stores in module the path of the assembly that actually called your function.

MC++
int GetDotNetCallerFileName(string& module)
{
    int res = FALSE;
    try
    {
        Assembly^ callerAssembly = Assembly::GetCallingAssembly();
        if(callerAssembly == nullptr)
            return FALSE;
        // Directory of the running CLR; it ends with a backslash,
        // hence the "\\" appended to directoryName below.
        String^ sysdir = 
         System::Runtime::InteropServices::RuntimeEnvironment::GetRuntimeDirectory();

        // Skip all .NET Framework assemblies and calls from this same assembly.
        String^ strCallerPath = callerAssembly->Location;
        String^ directoryName = nullptr;
        if(strCallerPath != nullptr)
            directoryName = Path::GetDirectoryName( strCallerPath ) + "\\";

        while((directoryName != nullptr 
               && String::Compare(sysdir, directoryName, true) == 0) 
              || callerAssembly == Assembly::GetExecutingAssembly())
        {
            // Assembly::GetCallingAssembly is static: it always reports the
            // caller of this function and cannot climb further up the stack.
            // Stop if it yields the same assembly instead of looping forever.
            Assembly^ nextAssembly = Assembly::GetCallingAssembly();
            if(nextAssembly == callerAssembly)
                break;
            callerAssembly = nextAssembly;
            // Recompute the directory from the new assembly's location.
            strCallerPath = callerAssembly->Location;
            directoryName = nullptr;
            if(strCallerPath != nullptr)
                directoryName = Path::GetDirectoryName( strCallerPath ) + "\\";
        }
.........
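The skipping rule in GetDotNetCallerFileName can be tried out without the CLR. The sketch below applies the same two tests (the assembly's directory equals the runtime directory, compared case-insensitively, or the assembly is the executing one) to a plain list of assembly paths, innermost caller first. It assumes those paths were collected beforehand; in the managed code they would come from Assembly::Location. The function names and the example paths are hypothetical.

```cpp
#include <algorithm>
#include <cctype>
#include <string>
#include <vector>

// ASCII case-insensitive equality, a simplification of
// String::Compare(a, b, true) that is adequate for typical Windows paths.
bool EqualsIgnoreCase(const std::string& a, const std::string& b) {
    return a.size() == b.size() &&
           std::equal(a.begin(), a.end(), b.begin(), [](char x, char y) {
               return std::tolower(static_cast<unsigned char>(x)) ==
                      std::tolower(static_cast<unsigned char>(y));
           });
}

// Directory part of a path including the trailing separator,
// mirroring Path::GetDirectoryName(path) + "\\" for backslash paths.
std::string DirectoryWithSlash(const std::string& path) {
    std::string::size_type pos = path.find_last_of("\\/");
    return pos == std::string::npos ? std::string() : path.substr(0, pos + 1);
}

// Return the first caller path that is neither inside the runtime directory
// nor the current ("self") assembly; empty if every frame was skipped.
std::string FirstNonFrameworkCaller(const std::vector<std::string>& callerPaths,
                                    const std::string& runtimeDir,
                                    const std::string& selfPath) {
    for (const std::string& path : callerPaths) {
        bool inRuntimeDir = EqualsIgnoreCase(DirectoryWithSlash(path), runtimeDir);
        bool isSelf = EqualsIgnoreCase(path, selfPath);
        if (!inRuntimeDir && !isSelf)
            return path;
    }
    return std::string();
}
```

Comparing whole directories rather than file names means any assembly installed in the runtime folder is skipped, which matches the intent of the sysdir test above.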

Calling the code

From native code, both direct calls and indirect calls (through another DLL such as Nativedll.dll) are straightforward:

MC++
const char *szc = GetCallingModulePath(1);
printf("expected result: test.exe\nfinal result = %s\n\n", szc);

szc = NativeDllCall(1);
printf("expected result: Nativedll.dll\nfinal result = %s\n", szc);

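char** pszc = NULL, **original = NULL;
// pszc receives an array of cnt module paths; the loop below frees each
// string and then the array itself with CoTaskMemFree.
int cnt = GetModuleStackTraceFromNative(pszc);
original = pszc;   // remember the array start, since pszc is advanced below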
printf("\nstack trace---------\n");
while(cnt-- > 0)
{
  printf("Trace: = %s\n", *pszc);
  CoTaskMemFree(*pszc);
  pszc++;
}
CoTaskMemFree(original);
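The cleanup loop above frees every string first and the array last, because both were allocated on the callee's side. A portable sketch of the same ownership contract, with std::malloc/std::free standing in for CoTaskMemAlloc/CoTaskMemFree (MakeStringArray and FreeStringArray are hypothetical names, not stackdumper.dll exports):

```cpp
#include <cstddef>
#include <cstdlib>
#include <cstring>

// Hypothetical producer: copies 'count' strings into a newly allocated array
// of newly allocated strings. The caller owns everything that is returned.
char** MakeStringArray(const char* const* items, int count) {
    char** arr = static_cast<char**>(std::malloc(sizeof(char*) * count));
    if (arr == nullptr)
        return nullptr;
    for (int i = 0; i < count; ++i) {
        std::size_t len = std::strlen(items[i]) + 1;   // include the terminator
        arr[i] = static_cast<char*>(std::malloc(len));
        if (arr[i] != nullptr)
            std::memcpy(arr[i], items[i], len);
    }
    return arr;
}

// Consumer-side cleanup: free each element first, then the array itself.
// Freeing the array first would lose the element pointers and leak them.
void FreeStringArray(char** arr, int count) {
    if (arr == nullptr)
        return;
    for (int i = 0; i < count; ++i)
        std::free(arr[i]);   // std::free(nullptr) is a no-op
    std::free(arr);
}
```

The essential rule carries over to the COM allocator: whoever receives the array must release it with the deallocator matching the producer's allocator.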

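There is no need to call CoTaskMemFree on the result of GetCallingModulePath, because the returned pointer refers to a global variable inside stackdumper.dll. If you call the function from managed code (such as C#), import it with DllImport, and declare the return type as IntPtr rather than string: for a string return value, the interop marshaler assumes the native buffer was allocated with CoTaskMemAlloc and frees it after copying, which would release memory that stackdumper.dll still owns.

C#
using System;
using System.Runtime.InteropServices;

[DllImport("stackdumper.dll")]
static extern IntPtr GetCallingModulePath(int arg);

and then invoke it as a static method, converting the returned pointer with Marshal.PtrToStringAnsi, for example string caller = Marshal.PtrToStringAnsi(GetCallingModulePath(1));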
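GetCallingModulePath hands out a pointer into a buffer owned by stackdumper.dll, which is why callers must not free it. A minimal portable sketch of that ownership pattern, assuming a single global std::string that stands in for the DLL's global variable (g_module and SetCallerResult are hypothetical names):

```cpp
#include <string>

// Hypothetical stand-in for the global result string inside stackdumper.dll.
static std::string g_module;

// Store a result and return a pointer into the DLL-owned buffer.
// The caller must NOT free this pointer. It stays valid only until the
// next call overwrites g_module, and because every caller shares one
// buffer, the pattern is not thread-safe.
const char* SetCallerResult(const std::string& path) {
    g_module = path;
    return g_module.c_str();
}
```

Callers should therefore copy the text right away, before any other call into the DLL can replace it.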

License

This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL)