📄 mstscax.cpp
字号:
#include "stdafx.h"
namespace
{
// Names from the MSTSCLib namespace (IID_IMsRdpClient, IID_IMsTscAx, ...); presumably
// generated from the mstscax type library and pulled in via stdafx.h -- confirm.
using namespace MSTSCLib;
// Signatures of the exports that init() resolves by name from the original mstscax.dll.
typedef HRESULT (STDAPICALLTYPE * PFNDLLGETCLASSOBJECT)(IN REFCLSID rclsid, IN REFIID riid, OUT LPVOID FAR * ppv);
typedef HRESULT (STDAPICALLTYPE * PFNDLLCANUNLOADNOW)(void);
typedef ULONG (STDAPICALLTYPE * PFNDLLGETTSCCTLVER)(void);
// Entry points of the original DLL. They are filled in by init() and stay NULL until then,
// or if the corresponding export cannot be found.
PFNDLLGETCLASSOBJECT pfnDllGetClassObject = NULL;
PFNDLLCANUNLOADNOW pfnDllCanUnloadNow = NULL;
PFNDLLGETTSCCTLVER pfnDllGetTscCtlVer = NULL;
// Handle of "original\mstscax.dll" (relative to this module's directory) loaded by init().
// NULL means "not loaded yet"; init() returns early once it is non-NULL.
HMODULE hmMstscax = NULL;
// Module base of this DLL, used as its own HMODULE for GetModuleFileName in init().
// __ImageBase is presumably the pseudo-symbol the MSVC linker places at the image base --
// confirm for other toolchains.
// NOTE(review): this extern "C" declaration sits inside an unnamed namespace and relies on
// C language linkage to bind to the linker-provided symbol anyway.
extern "C" char __ImageBase;
static const HMODULE hmSelf = reinterpret_cast<HMODULE>(&__ImageBase);
void init()
{
if(hmMstscax)
return;
TCHAR szFileName[MAX_PATH + 1];
GetModuleFileName(hmSelf, szFileName, MAX_PATH);
std::basic_string<TCHAR> strFileName(&szFileName[0]);
std::reverse_iterator<std::basic_string<TCHAR>::const_iterator > begin(strFileName.end());
std::reverse_iterator<std::basic_string<TCHAR>::const_iterator > end(strFileName.begin());
std::basic_string<TCHAR>::const_iterator endPath = std::find(begin, end, TEXT('\\')).base();
std::basic_string<TCHAR> strPath(strFileName.begin(), endPath);
strPath.append(TEXT("original\\mstscax.dll"));
hmMstscax = LoadLibrary(strPath.c_str());
pfnDllGetClassObject = (PFNDLLGETCLASSOBJECT)GetProcAddress(hmMstscax, "DllGetClassObject");
pfnDllCanUnloadNow = (PFNDLLCANUNLOADNOW)GetProcAddress(hmMstscax, "DllCanUnloadNow");
pfnDllGetTscCtlVer = (PFNDLLGETTSCCTLVER)GetProcAddress(hmMstscax, "DllGetTscCtlVer");
}
// printf-style debug logging: formats the message, appends a newline and sends it to
// OutputDebugString. Messages longer than the 0x1000-TCHAR buffer are truncated.
void dbgprintf(LPCTSTR fmt, ...)
{
	TCHAR buf[0x1000];

	// Format into all but the last TCHAR so the newline appended below always fits.
	// Otherwise a truncated message would lose its newline and run into the next line
	// in the debugger output.
	va_list args;
	va_start(args, fmt);
	StringCbVPrintf(buf, sizeof(buf) - sizeof(TCHAR), fmt, args);
	va_end(args);

	StringCbCat(buf, sizeof(buf), TEXT("\n"));
	OutputDebugString(buf);
}
// Disabled (#if 0): per-class tables of interface IIDs, apparently kept for reference.
// Each MsRdpClientN list repeats the previous version's IIDs and adds its own, and
// MsTscAxIIDs is identical to MsRdpClient. IID_IMsTscAxEvents is commented out of every
// list, presumably because it is the outgoing (event) interface rather than one the
// control implements. Confirm before re-enabling.
#if 0
const IID MsTscAxIIDs[] =
{
IID_IMsRdpClient,
IID_IMsTscAx,
//IID_IMsTscAxEvents,
IID_IMsTscNonScriptable,
IID_IMsRdpClientNonScriptable,
};
const IID MsRdpClient[] =
{
IID_IMsRdpClient,
IID_IMsTscAx,
//IID_IMsTscAxEvents,
IID_IMsTscNonScriptable,
IID_IMsRdpClientNonScriptable,
};
const IID MsRdpClient2[] =
{
IID_IMsRdpClient2,
IID_IMsRdpClient,
IID_IMsTscAx,
//IID_IMsTscAxEvents,
IID_IMsTscNonScriptable,
IID_IMsRdpClientNonScriptable,
};
const IID MsRdpClient3[] =
{
IID_IMsRdpClient3,
IID_IMsRdpClient2,
IID_IMsRdpClient,
IID_IMsTscAx,
//IID_IMsTscAxEvents,
IID_IMsTscNonScriptable,
IID_IMsRdpClientNonScriptable,
};
const IID MsRdpClient4[] =
{
IID_IMsRdpClient4,
IID_IMsRdpClient3,
IID_IMsRdpClient2,
IID_IMsRdpClient,
IID_IMsTscAx,
//IID_IMsTscAxEvents,
IID_IMsTscNonScriptable,
IID_IMsRdpClientNonScriptable,
IID_IMsRdpClientNonScriptable2,
};
#endif
std::wstring UUIDToString(const UUID& uuid)
{
std::wstring s;
LPOLESTR str;
StringFromCLSID(uuid, &str);
s += str;
CoTaskMemFree(str);
return s;
}
// Returns the display name of a moniker for logging.
// Returns "<null>" for a NULL moniker (like the other *ToString helpers here) and
// "<error>" if IMoniker::GetDisplayName fails.
std::wstring MonikerToString(IMoniker * pmk)
{
	// A logging helper must not crash the host on a NULL argument.
	if(pmk == NULL)
		return std::wstring(L"<null>");

	LPOLESTR pszName = NULL;
	if(SUCCEEDED(pmk->GetDisplayName(NULL, NULL, &pszName)) && pszName != NULL)
	{
		std::wstring s(pszName);
		CoTaskMemFree(pszName);
		return s;
	}
	return std::wstring(L"<error>");
}
std::basic_string<TCHAR> RectToString(const RECT& rc)
{
if(&rc == NULL)
return TEXT("<null>");
std::basic_ostringstream<TCHAR> o;
o << "{" << " left:" << rc.left << " top:" << rc.top << " right:" << rc.right << " bottom:" << rc.bottom << " }";
return o.str();
}
std::basic_string<TCHAR> RectToString(const RECTL& rc)
{
if(&rc == NULL)
return TEXT("<null>");
std::basic_ostringstream<TCHAR> o;
o << "{" << " left:" << rc.left << " top:" << rc.top << " right:" << rc.right << " bottom:" << rc.bottom << " }";
return o.str();
}
std::basic_string<TCHAR> SizeToString(const SIZE& sz)
{
if(&sz == NULL)
return TEXT("<null>");
std::basic_ostringstream<TCHAR> o;
o << "{ " << " cx:" << sz.cx << " cy:" << sz.cy << " }";
return o.str();
}
// Returns the static text "true" or "false" for any value usable as a condition
// (bool, BOOL, VARIANT_BOOL, pointers, ...).
template<class T> LPCTSTR BooleanToString(const T& X)
{
	if(X)
		return TEXT("true");
	return TEXT("false");
}
// Describes a VARIANT for debug output:
//   "<type>[]"          if VT_ARRAY is set,
//   "<type> *"          if VT_BYREF is set,
//   "<type> = <value>"  for by-value types whose value is printed,
//   "<type>"            for by-value types without a printable value.
std::basic_string<TCHAR> VariantToString(const VARIANT& var)
{
	std::basic_ostringstream<TCHAR> o;

	// Name of the base type, with the VT_ARRAY / VT_BYREF modifier bits masked off.
	switch(var.vt & VT_TYPEMASK)
	{
	case VT_EMPTY: o << "<empty>"; break;
	case VT_NULL: o << "<null>"; break;
	case VT_I2: o << "short"; break;
	case VT_I4: o << "long"; break;
	case VT_R4: o << "float"; break;
	case VT_R8: o << "double"; break;
	case VT_CY: o << "CURRENCY"; break;
	case VT_DATE: o << "DATE"; break;
	case VT_BSTR: o << "string"; break;
	case VT_DISPATCH: o << "IDispatch *"; break;
	case VT_ERROR: o << "SCODE"; break;
	case VT_BOOL: o << "bool"; break;
	case VT_VARIANT: o << "VARIANT *"; break;
	case VT_UNKNOWN: o << "IUnknown *"; break;
	case VT_DECIMAL: o << "DECIMAL"; break;
	case VT_I1: o << "char"; break;
	case VT_UI1: o << "unsigned char"; break;
	case VT_UI2: o << "unsigned short"; break;
	case VT_UI4: o << "unsigned long"; break;
	case VT_I8: o << "long long"; break;
	case VT_UI8: o << "unsigned long long"; break;
	case VT_INT: o << "int"; break;
	case VT_UINT: o << "unsigned int"; break;
	case VT_VOID: o << "void"; break;
	case VT_HRESULT: o << "HRESULT"; break;
	case VT_PTR: o << "void *"; break;
	case VT_SAFEARRAY: o << "SAFEARRAY *"; break;
	case VT_LPSTR: o << "LPSTR"; break;
	case VT_LPWSTR: o << "LPWSTR"; break;
	case VT_RECORD: o << "struct { }"; break;
	case VT_INT_PTR: o << "intptr_t"; break;
	case VT_UINT_PTR: o << "uintptr_t"; break;
	case VT_FILETIME: o << "FILETIME"; break;
	default: o << "???"; break;
	}

	if(var.vt & VT_ARRAY)
		o << "[]";
	else if(var.vt & VT_BYREF)
		o << " *";
	else
	{
		// One switch decides both whether a value is printed and what it is, so " = " is
		// written only when a value follows. Types without a value fall into default
		// instead of asserting.
		std::basic_ostringstream<TCHAR> value;
		bool hasValue = true;
		switch(var.vt & VT_TYPEMASK)
		{
		case VT_I2: value << var.iVal; break;
		case VT_I4: value << var.lVal; break;
		case VT_R4: value << var.fltVal; break;
		case VT_R8: value << var.dblVal; break;
		case VT_BSTR:
			// A NULL BSTR is a valid empty string; SysStringLen(NULL) is 0.
			if(var.bstrVal)
				value << std::wstring(var.bstrVal, var.bstrVal + SysStringLen(var.bstrVal));
			break;
		// Parenthesised: << binds tighter than ?:, so without them the raw short is printed.
		case VT_BOOL: value << (var.boolVal ? "true" : "false"); break;
		case VT_I1: value << static_cast<int>(var.cVal); break;
		case VT_UI1: value << static_cast<unsigned int>(var.bVal); break;
		case VT_UI2: value << var.uiVal; break;
		case VT_UI4: value << var.ulVal; break;
		case VT_I8: value << var.llVal; break;
		case VT_UI8: value << var.ullVal; break;
		case VT_INT: value << var.intVal; break;
		case VT_UINT: value << var.uintVal; break;
		case VT_LPSTR: value << static_cast<LPSTR>(var.byref); break;
		case VT_LPWSTR: value << static_cast<LPWSTR>(var.byref); break;
		// BUGBUG: intVal/uintVal are 32-bit INT/UINT; a pointer-sized value is presumably
		// truncated on 64-bit builds.
		case VT_INT_PTR: value << var.intVal; break;
		case VT_UINT_PTR: value << var.uintVal; break;
		case VT_DISPATCH:
		case VT_VARIANT:
		case VT_UNKNOWN:
		case VT_PTR:
		case VT_SAFEARRAY:
			value << var.byref; break;
		case VT_ERROR:
		case VT_HRESULT:
			value << std::hex << var.ulVal; break;
		default:
			// No value printed:
			//   - VT_EMPTY, VT_NULL, VT_VOID, VT_RECORD;
			//   - VT_CY, VT_DATE, VT_DECIMAL, VT_FILETIME (TODO);
			//   - any VARTYPE not listed above.
			hasValue = false;
			break;
		}
		if(hasValue)
			o << " = " << value.str();
	}
	return o.str();
}
// NOTE(review): presumably silences MSVC warning C4584 (a base class appearing more than
// once in a class's base list/hierarchy), which hook classes later in the file may
// trigger. Confirm.
#pragma warning(disable:4584)
// Forward declarations of the hook factories. Their definitions are not in this part of
// the file. Judging by their names, each presumably wraps the given interface pointer in a
// logging proxy (such as CConnectionPointContainer for IConnectionPointContainer).
// Confirm against the definitions.
IConnectionPointContainer * HookIConnectionPointContainer(IConnectionPointContainer * p);
IEnumConnectionPoints * HookIEnumConnectionPoints(IEnumConnectionPoints * p);
IConnectionPoint * HookIConnectionPoint(IConnectionPoint * p);
IEnumConnections * HookIEnumConnections(IEnumConnections * p);
class CConnectionPointContainer: public IConnectionPointContainer
{
private:
LONG m_refCount;
IConnectionPointContainer * m_IConnectionPointContainer;
public:
// Wraps an existing container without calling AddRef on it: the caller's reference is
// taken over and released in the destructor.
CConnectionPointContainer(IConnectionPointContainer * pIConnectionPointContainer):
	m_refCount(1), // the creator holds the first reference to this wrapper
	m_IConnectionPointContainer(pIConnectionPointContainer)
{
}
~CConnectionPointContainer()
{
	// Give back the reference on the wrapped container taken over by the constructor.
	m_IConnectionPointContainer->Release();
}
// IUnknown::QueryInterface for the logging wrapper. Only IUnknown and
// IConnectionPointContainer are exposed. The request and result are both logged.
// Returns:
//   S_OK           with an AddRef'd pointer,
//   E_NOINTERFACE  with *ppvObject = NULL,
//   E_POINTER      if ppvObject is NULL.
virtual HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, void ** ppvObject)
{
	dbgprintf(TEXT("CConnectionPointContainer::QueryInterface(%ls, %p)"), UUIDToString(riid).c_str(), ppvObject);

	// A NULL out-pointer must be rejected, not dereferenced.
	if(ppvObject == NULL)
	{
		dbgprintf(TEXT("CConnectionPointContainer::QueryInterface -> %08X, ppvObject = %p"), E_POINTER, ppvObject);
		return E_POINTER;
	}

	HRESULT hr = S_OK;
	if(riid == IID_IUnknown || riid == IID_IConnectionPointContainer)
	{
		*ppvObject = static_cast<IConnectionPointContainer *>(this);
		// COM rule: a successful QueryInterface returns a new reference. Without this,
		// the caller's matching Release would drop the count to 0 and delete this object
		// while its creator still holds it.
		AddRef();
	}
	else
	{
		*ppvObject = NULL;
		hr = E_NOINTERFACE;
	}
	dbgprintf(TEXT("CConnectionPointContainer::QueryInterface -> %08X, ppvObject = %p"), hr, *ppvObject);
	return hr;
}
// IUnknown::AddRef: atomically bumps the wrapper's reference count and returns the new value.
virtual ULONG STDMETHODCALLTYPE AddRef(void)
{
	const LONG newCount = InterlockedIncrement(&m_refCount);
	return newCount;
}
virtual ULONG STDMETHODCALLTYPE Release(void)
{
LONG n = InterlockedDecrement(&m_refCount);
if(n == 0)
delete this;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -