⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 mstscax.cpp

📁 ReactOS是一些高手根据Windows XP的内核编写出的类XP。内核实现机理和API函数调用几乎相同。甚至可以兼容XP的程序。喜欢研究系统内核的人可以看一看。
💻 CPP
📖 第 1 页 / 共 5 页
字号:
#include "stdafx.h"

namespace
{
	using namespace MSTSCLib;

	// Signatures of the exports that this wrapper forwards to the original
	// mstscax.dll (see init()).
	typedef HRESULT (STDAPICALLTYPE * PFNDLLGETCLASSOBJECT)(IN REFCLSID rclsid, IN REFIID riid, OUT LPVOID FAR * ppv);
	typedef HRESULT (STDAPICALLTYPE * PFNDLLCANUNLOADNOW)(void);
	typedef ULONG (STDAPICALLTYPE * PFNDLLGETTSCCTLVER)(void);

	// Entry points resolved from the original DLL by init() via GetProcAddress;
	// NULL until init() has assigned them.
	PFNDLLGETCLASSOBJECT pfnDllGetClassObject = NULL;
	PFNDLLCANUNLOADNOW pfnDllCanUnloadNow = NULL;
	PFNDLLGETTSCCTLVER pfnDllGetTscCtlVer = NULL;

	// Handle of the loaded original mstscax.dll; assigned by init(), which
	// returns early once it is non-NULL.
	HMODULE hmMstscax = NULL;

	// The address of __ImageBase is reinterpreted as this module's own HMODULE;
	// init() passes it to GetModuleFileName to find the directory of this DLL.
	// NOTE(review): presumably the linker-provided image-base symbol — confirm
	// it is available with every toolchain this is built with.
	extern "C" char __ImageBase;
	static const HMODULE hmSelf = reinterpret_cast<HMODULE>(&__ImageBase);

	void init()
	{
		if(hmMstscax)
			return;

		TCHAR szFileName[MAX_PATH + 1];
		GetModuleFileName(hmSelf, szFileName, MAX_PATH);

		std::basic_string<TCHAR> strFileName(&szFileName[0]);
		std::reverse_iterator<std::basic_string<TCHAR>::const_iterator > begin(strFileName.end());
		std::reverse_iterator<std::basic_string<TCHAR>::const_iterator > end(strFileName.begin());
		std::basic_string<TCHAR>::const_iterator endPath = std::find(begin, end, TEXT('\\')).base();

		std::basic_string<TCHAR> strPath(strFileName.begin(), endPath);
		strPath.append(TEXT("original\\mstscax.dll"));

		hmMstscax = LoadLibrary(strPath.c_str());
		pfnDllGetClassObject = (PFNDLLGETCLASSOBJECT)GetProcAddress(hmMstscax, "DllGetClassObject");
		pfnDllCanUnloadNow = (PFNDLLCANUNLOADNOW)GetProcAddress(hmMstscax, "DllCanUnloadNow");
		pfnDllGetTscCtlVer = (PFNDLLGETTSCCTLVER)GetProcAddress(hmMstscax, "DllGetTscCtlVer");
	}

	// printf-style debug logging: formats fmt/... (TCHAR) into a fixed
	// 0x1000-TCHAR buffer, appends a newline and passes the result to
	// OutputDebugString. Over-long messages are truncated by StringCbVPrintf.
	void dbgprintf(LPCTSTR fmt, ...)
	{
		TCHAR buf[0x1000];

		va_list args;
		va_start(args, fmt);
		// Format into all but the last TCHAR so that the newline appended below
		// still fits when the message had to be truncated; otherwise the
		// concatenation fails and consecutive messages run together.
		StringCbVPrintf(buf, sizeof(buf) - sizeof(TCHAR), fmt, args);
		va_end(args);

		StringCbCat(buf, sizeof(buf), TEXT("\n"));

		OutputDebugString(buf);
	}

// Disabled (not compiled): per-class lists of interface IIDs, apparently meant
// for the MsTscAx / MsRdpClient* classes. Each later list extends the previous
// one with the newer IMsRdpClientN interface. Note that MsTscAxIIDs and
// MsRdpClient currently contain identical entries.
#if 0
	const IID MsTscAxIIDs[] =
	{
		IID_IMsRdpClient,
		IID_IMsTscAx,
		//IID_IMsTscAxEvents,
		IID_IMsTscNonScriptable,
		IID_IMsRdpClientNonScriptable,
	};

	const IID MsRdpClient[] =
	{
		IID_IMsRdpClient,
		IID_IMsTscAx,
		//IID_IMsTscAxEvents,
		IID_IMsTscNonScriptable,
		IID_IMsRdpClientNonScriptable,
	};

	const IID MsRdpClient2[] =
	{
		IID_IMsRdpClient2,
		IID_IMsRdpClient,
		IID_IMsTscAx,
		//IID_IMsTscAxEvents,
		IID_IMsTscNonScriptable,
		IID_IMsRdpClientNonScriptable,
	};

	const IID MsRdpClient3[] =
	{
		IID_IMsRdpClient3,
		IID_IMsRdpClient2,
		IID_IMsRdpClient,
		IID_IMsTscAx,
		//IID_IMsTscAxEvents,
		IID_IMsTscNonScriptable,
		IID_IMsRdpClientNonScriptable,
	};

	const IID MsRdpClient4[] =
	{
		IID_IMsRdpClient4,
		IID_IMsRdpClient3,
		IID_IMsRdpClient2,
		IID_IMsRdpClient,
		IID_IMsTscAx,
		//IID_IMsTscAxEvents,
		IID_IMsTscNonScriptable,
		IID_IMsRdpClientNonScriptable,
		IID_IMsRdpClientNonScriptable2,
	};
#endif

	std::wstring UUIDToString(const UUID& uuid)
	{
		std::wstring s;
		LPOLESTR str;
		StringFromCLSID(uuid, &str);
		s += str;
		CoTaskMemFree(str);
		return s;
	}

	// Returns the display name of a moniker for debug output.
	// Returns "<null>" for a NULL moniker and "<error>" if a bind context cannot
	// be created or IMoniker::GetDisplayName fails.
	std::wstring MonikerToString(IMoniker * pmk)
	{
		if(pmk == NULL)
			return std::wstring(L"<null>");

		// GetDisplayName expects a real bind context; passing NULL is not
		// supported by every moniker implementation.
		IBindCtx * pbc = NULL;

		if(FAILED(CreateBindCtx(0, &pbc)) || pbc == NULL)
			return std::wstring(L"<error>");

		LPOLESTR pszName = NULL;
		HRESULT hr = pmk->GetDisplayName(pbc, NULL, &pszName);
		pbc->Release();

		// Guard against a "successful" call that returned no string:
		// constructing std::wstring from NULL is undefined behaviour.
		if(FAILED(hr) || pszName == NULL)
			return std::wstring(L"<error>");

		std::wstring s(pszName);
		CoTaskMemFree(pszName);
		return s;
	}

	std::basic_string<TCHAR> RectToString(const RECT& rc)
	{
		if(&rc == NULL)
			return TEXT("<null>");

		std::basic_ostringstream<TCHAR> o;
		o << "{" << " left:" << rc.left << " top:" << rc.top << " right:" << rc.right << " bottom:" << rc.bottom << " }";
		return o.str();
	}

	std::basic_string<TCHAR> RectToString(const RECTL& rc)
	{
		if(&rc == NULL)
			return TEXT("<null>");

		std::basic_ostringstream<TCHAR> o;
		o << "{" << " left:" << rc.left << " top:" << rc.top << " right:" << rc.right << " bottom:" << rc.bottom << " }";
		return o.str();
	}

	std::basic_string<TCHAR> SizeToString(const SIZE& sz)
	{
		if(&sz == NULL)
			return TEXT("<null>");

		std::basic_ostringstream<TCHAR> o;
		o << "{ " << " cx:" << sz.cx << " cy:" << sz.cy << " }";
		return o.str();
	}

	// Returns the literal text "true" or "false" (as a TCHAR string) depending
	// on whether X evaluates as true in a boolean context.
	template<class T> LPCTSTR BooleanToString(const T& X)
	{
		if(X)
			return TEXT("true");

		return TEXT("false");
	}

	std::basic_string<TCHAR> VariantToString(const VARIANT& var)
	{
		std::basic_ostringstream<TCHAR> o;

		switch(var.vt & VT_TYPEMASK)
		{
		case VT_EMPTY:           o << "<empty>"; break;
		case VT_NULL:            o << "<null>"; break;
		case VT_I2:              o << "short"; break;
		case VT_I4:              o << "long"; break;
		case VT_R4:              o << "float"; break;
		case VT_R8:              o << "double"; break;
		case VT_CY:              o << "CURRENCY"; break;
		case VT_DATE:            o << "DATE"; break;
		case VT_BSTR:            o << "string"; break;
		case VT_DISPATCH:        o << "IDispatch *"; break;
		case VT_ERROR:           o << "SCODE"; break;
		case VT_BOOL:            o << "bool"; break;
		case VT_VARIANT:         o << "VARIANT *"; break;
		case VT_UNKNOWN:         o << "IUnknown *"; break;
		case VT_DECIMAL:         o << "DECIMAL"; break;
		case VT_I1:              o << "char"; break;
		case VT_UI1:             o << "unsigned char"; break;
		case VT_UI2:             o << "unsigned short"; break;
		case VT_UI4:             o << "unsigned long"; break;
		case VT_I8:              o << "long long"; break;
		case VT_UI8:             o << "unsigned long long"; break;
		case VT_INT:             o << "int"; break;
		case VT_UINT:            o << "unsigned int"; break;
		case VT_VOID:            o << "void"; break;
		case VT_HRESULT:         o << "HRESULT"; break;
		case VT_PTR:             o << "void *"; break;
		case VT_SAFEARRAY:       o << "SAFEARRAY *"; break;
		case VT_LPSTR:           o << "LPSTR"; break;
		case VT_LPWSTR:          o << "LPWSTR"; break;
		case VT_RECORD:          o << "struct { }"; break;
		case VT_INT_PTR:         o << "intptr_t"; break;
		case VT_UINT_PTR:        o << "uintptr_t"; break;
		case VT_FILETIME:        o << "FILETIME"; break;
		default:                 o << "???"; break;
		}

		if(var.vt & VT_ARRAY)
			o << "[]";
		else if(var.vt & VT_BYREF)
			o << " *";
		else
		{
			switch(var.vt & VT_TYPEMASK)
			{
			case VT_EMPTY:
			case VT_NULL:
			case VT_RECORD:
			case VT_VOID:

				// TODO
			case VT_CY:
			case VT_DATE:
			case VT_DECIMAL:
			case VT_FILETIME:
				break;

			default:
				o << " = ";
			}

			switch(var.vt & VT_TYPEMASK)
			{
			case VT_I2:       o << var.iVal; break;
			case VT_I4:       o << var.lVal; break;
			case VT_R4:       o << var.fltVal; break;
			case VT_R8:       o << var.dblVal; break;
			case VT_BSTR:     o << std::wstring(var.bstrVal, var.bstrVal + SysStringLen(var.bstrVal)); break;
			case VT_BOOL:     o << var.boolVal ? "true" : "false"; break;
			case VT_I1:       o << int(var.cVal); break;
			case VT_UI1:      o << unsigned int(var.bVal); break;
			case VT_UI2:      o << var.uiVal; break;
			case VT_UI4:      o << var.ulVal; break;
			case VT_I8:       o << var.llVal; break;
			case VT_UI8:      o << var.ullVal; break;
			case VT_INT:      o << var.intVal; break;
			case VT_UINT:     o << var.uintVal; break;
			case VT_LPSTR:    o << LPSTR(var.byref); break;
			case VT_LPWSTR:   o << LPWSTR(var.byref); break;
			case VT_INT_PTR:  o << var.intVal; break; // BUGBUG
			case VT_UINT_PTR: o << var.uintVal; break; // BUGBUG

			case VT_DISPATCH:
			case VT_VARIANT:
			case VT_UNKNOWN:
			case VT_PTR:
			case VT_SAFEARRAY:
			case VT_RECORD:
				o << var.byref; break;

			case VT_ERROR:
			case VT_HRESULT:
				o << std::hex << var.ulVal; break;

			case VT_EMPTY:
			case VT_NULL:
			case VT_VOID:
				break;

			default:
				assert(0);
			}
		}

		return o.str();
	}

// Silences MSVC warning C4584 (a class names a base class that is already a
// base of one of its other bases) for the rest of this file.
#pragma warning(disable:4584)

	// Forward declarations of factory functions that take an interface pointer
	// and return an interface of the same type. Their definitions are not
	// visible here; presumably they wrap the argument in logging hooks such as
	// CConnectionPointContainer — confirm against the definitions.
	IConnectionPointContainer * HookIConnectionPointContainer(IConnectionPointContainer * p);
	IEnumConnectionPoints * HookIEnumConnectionPoints(IEnumConnectionPoints * p);
	IConnectionPoint * HookIConnectionPoint(IConnectionPoint * p);
	IEnumConnections * HookIEnumConnections(IEnumConnections * p);

	class CConnectionPointContainer: public IConnectionPointContainer
	{
	private:
		LONG m_refCount;
		IConnectionPointContainer * m_IConnectionPointContainer;

	public:
		// Wraps an existing IConnectionPointContainer. The wrapper keeps the
		// passed-in pointer (the destructor Releases it) and starts with a
		// reference count of 1, which belongs to the creator.
		CConnectionPointContainer(IConnectionPointContainer * pIConnectionPointContainer)
			: m_refCount(1), m_IConnectionPointContainer(pIConnectionPointContainer)
		{
		}

		// Releases the wrapped container held by this wrapper.
		~CConnectionPointContainer()
		{
			m_IConnectionPointContainer->Release();
		}

		// IUnknown::QueryInterface for the wrapper. Only IUnknown and
		// IConnectionPointContainer are exposed; both resolve to this object.
		// Returns E_POINTER for a NULL out pointer and E_NOINTERFACE (with
		// *ppvObject set to NULL) for any other IID. Entry and result are logged
		// through dbgprintf.
		virtual HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, void ** ppvObject)
		{
			HRESULT hr = S_OK;

			dbgprintf(TEXT("CConnectionPointContainer::QueryInterface(%ls, %p)"), UUIDToString(riid).c_str(), ppvObject);

			if(ppvObject == NULL)
				return E_POINTER;

			if(riid == IID_IUnknown || riid == IID_IConnectionPointContainer)
			{
				*ppvObject = static_cast<IConnectionPointContainer *>(this);

				// COM contract: an interface pointer handed out by QueryInterface
				// carries its own reference. Without this, the caller's matching
				// Release would drop the count to zero and delete the object.
				AddRef();
			}
			else
			{
				*ppvObject = NULL;
				hr = E_NOINTERFACE;
			}

			dbgprintf(TEXT("CConnectionPointContainer::QueryInterface -> %08X, ppvObject = %p"), hr, *ppvObject);
			return hr;
		}

		// IUnknown::AddRef: atomically increments the reference count and
		// returns the new count.
		virtual ULONG STDMETHODCALLTYPE AddRef(void)
		{
			const LONG newCount = InterlockedIncrement(&m_refCount);
			return static_cast<ULONG>(newCount);
		}

		virtual ULONG STDMETHODCALLTYPE Release(void)
		{
			LONG n = InterlockedDecrement(&m_refCount);

			if(n == 0)
				delete this;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -