📄 connectionpoint.cpp
字号:
//
// Copyright (c) Microsoft Corporation. All rights reserved.
//
//
// Use of this source code is subject to the terms of the Microsoft shared
// source or premium shared source license agreement under which you licensed
// this source code. If you did not accept the terms of the license agreement,
// you are not authorized to use this source code. For the terms of the license,
// please see the license agreement between you and Microsoft or, if applicable,
// see the SOURCE.RTF on your install media or the root of your tools installation.
// THE SOURCE CODE IS PROVIDED "AS IS", WITH NO WARRANTIES.
//
#include "pch.h"
#pragma hdrstop
#include "Winsock2.h"
#include "Iphlpapi.h"
#include "upnp.h"
#include "ssdpapi.h"
#include "Ncbase.h"
#include "ipsupport.h"
#include "upnp_config.h"
#include "auto_xxx.hxx"
#include "variant.h"
#include "HttpRequest.h"
#include "ConnectionPoint.h"
// Global ConnectionPoint pointer (static storage, so it starts out NULL).
ConnectionPoint* g_pConnectionPoint = NULL;
// SAX reader shared by all sinks; created and destroyed by ConnectionPoint::listening_thread.
ce::SAXReader* ConnectionPoint::sink::m_pReader = NULL;
// ConnectionPoint
ConnectionPoint::ConnectionPoint() :
    m_pThreadPool(NULL),
    m_bInitialized(false)
{
    // Manual-reset, initially non-signaled; uninit() sets it to stop the listening thread.
    m_hEventShuttingDown = CreateEvent(NULL, true, false, NULL);
    // Auto-reset, initially non-signaled; listening_thread sets it once running and init() waits on it.
    m_hStarted = CreateEvent(NULL, false, false, NULL);

    // Options for the read end of the notification message queue.
    MSGQUEUEOPTIONS queueOptions = {0};
    queueOptions.dwSize = sizeof(MSGQUEUEOPTIONS);
    queueOptions.dwFlags = MSGQUEUE_ALLOW_BROKEN;
    queueOptions.cbMaxMessage = MAX_MSG_SIZE;
    queueOptions.bReadAccess = TRUE;
    queueOptions.dwMaxMessages = upnp_config::notification_queue_size();

    // Start from a random suffix and probe upward until CreateMsgQueue
    // no longer reports that the name is already in use.
    srand(GetTickCount());
    DWORD nSuffix = rand();
    DWORD dwCreateError;
    do
    {
        m_strMsgQueueName.reserve(20);
        _snwprintf(m_strMsgQueueName.get_buffer(), m_strMsgQueueName.capacity(), L"UPnPMsgQueue%x", nSuffix);
        m_hMsgQueue = CreateMsgQueue(m_strMsgQueueName, &queueOptions);
        dwCreateError = GetLastError();
        if(ERROR_ALREADY_EXISTS == dwCreateError)
        {
            // name collision - release this handle and try the next suffix
            CloseMsgQueue(m_hMsgQueue);
        }
        ++nSuffix;
    }
    while(ERROR_ALREADY_EXISTS == dwCreateError);

    TraceTag(ttidEvents, "Created new MsgQueue [%s]", m_strMsgQueueName);
    assert(m_hMsgQueue.valid());
}
// init
// Creates the worker thread pool and starts listening_thread, then waits until
// that thread reports it is running. Called with m_csListSinks held (from advise()).
// Returns S_OK; E_FAIL if a previous listening thread is still alive after 5 s
// or the new thread exits before signaling m_hStarted; E_OUTOFMEMORY if the pool
// or thread cannot be created. On failure m_pThreadPool is left NULL.
HRESULT ConnectionPoint::init()
{
    ce::gate<ce::critical_section> _gate(m_csInit);

    assert(!m_pThreadPool);

    // A listening thread from a previous init/uninit cycle may still be exiting;
    // give it up to 5 seconds. Only wait when a previous thread handle exists.
    if(m_hListeningThread.valid() && WAIT_TIMEOUT == WaitForSingleObject(m_hListeningThread, 5 * 1000))
        return E_FAIL;

    // init thread pool
    if(!(m_pThreadPool = new SVSThreadPool(10))) // max 10 threads
        return E_OUTOFMEMORY;

    ResetEvent(m_hEventShuttingDown);

    // start listening thread
    m_hListeningThread = CreateThread(NULL, 0, &listening_thread, this, 0, NULL);
    if(!m_hListeningThread.valid())
    {
        // don't leak the pool; the next init() asserts m_pThreadPool is NULL
        delete m_pThreadPool;
        m_pThreadPool = NULL;
        return E_OUTOFMEMORY;
    }

    // Wait for listening_thread to signal m_hStarted. The thread can exit without
    // signaling (e.g. when LoadLibrary fails), so also wait on the thread handle;
    // waiting on m_hStarted alone would block forever in that case.
    HANDLE hWait[2] = {m_hStarted, m_hListeningThread};
    if(WAIT_OBJECT_0 != WaitForMultipleObjects(sizeof(hWait)/sizeof(*hWait), hWait, false, INFINITE))
    {
        delete m_pThreadPool;
        m_pThreadPool = NULL;
        return E_FAIL;
    }

    m_bInitialized = true;
    return S_OK;
}
// ~ConnectionPoint
ConnectionPoint::~ConnectionPoint()
{
    // Sinks do not stop their own timers when destroyed (they are copied in and
    // out of containers), so any sink that was never unadvised must have its
    // timers stopped here, before m_listSinks itself is destroyed.
    // NOTE(review): this does not call uninit() or wait for listening_thread;
    // confirm all sinks are always unadvised before destruction.
    ce::list<sink>::iterator itSink = m_listSinks.begin();
    ce::list<sink>::iterator itStop = m_listSinks.end();
    while(itSink != itStop)
    {
        itSink->stop_timers();
        ++itSink;
    }

    // release the read end of the notification message queue
    CloseMsgQueue(m_hMsgQueue);
}
// uninit
// Undoes init(): asks the listening thread to exit, destroys the thread pool,
// and marks the object as needing init() again. It does not wait for the thread;
// unadvise() waits on m_hListeningThread after calling this.
void ConnectionPoint::uninit()
{
    ce::gate<ce::critical_section> _gate(m_csInit);

    const bool bThreadStarted = m_hListeningThread.valid();
    if(bThreadStarted)
        SetEvent(m_hEventShuttingDown); // listening_thread breaks out of its loop on this

    delete m_pThreadPool;
    m_pThreadPool = NULL;

    m_bInitialized = false;
}
// advise
// Registers a new sink for notifications about pwszUSN and returns a cookie for it.
// The first sink triggers init().
//   pwszUSN   - USN to watch; also passed to an initial alive() on the new sink
//   nLifeTime - lifetime passed to that initial alive()
//   pCallback - callback the sink is constructed with (asserted non-NULL)
//   pdwCookie - receives the cookie: the sink's list iterator stored in a DWORD,
//               as expected by unadvise() and subscribe()
// Returns S_OK, or the failure HRESULT from init() or register_notification().
HRESULT ConnectionPoint::advise(LPCWSTR pwszUSN, UINT nLifeTime, ICallback* pCallback, DWORD* pdwCookie)
{
    assert(pCallback);
    assert(pdwCookie);

    HRESULT hr = S_OK;
    ce::gate<ce::critical_section> _gate(m_csListSinks);

    // first sink: bring up the thread pool and listening thread
    if(!m_bInitialized)
    {
        hr = init();
        if(FAILED(hr))
            return hr;
    }

    assert(m_hMsgQueue.valid());
    assert(m_hListeningThread.valid());
    assert(m_pThreadPool);

    sink newSink(pwszUSN, pCallback, *m_pThreadPool);

    // connect the sink to our message queue before it is published in the list
    hr = newSink.register_notification(m_strMsgQueueName);
    if(FAILED(hr))
        return hr;

    // the list stores a copy; newSink itself goes away at end of scope
    m_listSinks.push_front(newSink);
    ce::list<sink>::iterator itNew = m_listSinks.begin();

    itNew->alive(pwszUSN, NULL, NULL, nLifeTime);

    // the cookie is the iterator's bits, which only works while an iterator fits in a DWORD
    assert(sizeof(ce::list<sink>::iterator) == sizeof(*pdwCookie));
    *reinterpret_cast<ce::list<sink>::iterator*>(pdwCookie) = itNew;

    return S_OK;
}
// unadvise
// Removes the sink identified by dwCookie (a value produced by advise()).
// Unknown cookies are ignored. When the last sink of an initialized object is
// removed, uninit() is called and we wait up to 60 s for the listening thread to exit.
void ConnectionPoint::unadvise(DWORD dwCookie)
{
    assert(sizeof(ce::list<sink>::iterator) == sizeof(dwCookie));
    ce::list<sink>::iterator itTarget = *reinterpret_cast<ce::list<sink>::iterator*>(&dwCookie);

    ce::gate<ce::critical_section> _gate(m_csListSinks);

    // The cookie comes from the caller and may be stale, so only use it after
    // finding it among the live sinks.
    ce::list<sink>::iterator itCur = m_listSinks.begin();
    ce::list<sink>::iterator itLast = m_listSinks.end();
    while(itCur != itLast && !(itCur == itTarget))
        ++itCur;

    if(itCur != itLast)
    {
        assert(m_bInitialized);
        itCur->stop_timers();
        itCur->unsubscribe();
        itCur->deregister_notification();
        m_listSinks.erase(itCur);
    }

    // nothing else to do unless that was the last sink of an initialized object
    if(!m_bInitialized || !m_listSinks.empty())
        return;

    uninit();

    // Release m_csListSinks before waiting: listening_thread may need to enter it to exit.
    _gate.leave();

    DWORD dwWait = WaitForSingleObject(m_hListeningThread, 60 * 1000);
    assert(WAIT_TIMEOUT != dwWait);
}
// subscribe
// Subscribes the sink identified by dwCookie (from advise()) to events at pszURL.
// Returns the sink's subscribe() result, or E_FAIL if the cookie matches no live sink.
HRESULT ConnectionPoint::subscribe(DWORD dwCookie, LPCSTR pszURL)
{
    assert(sizeof(ce::list<sink>::iterator) == sizeof(dwCookie));
    ce::list<sink>::iterator itTarget = *reinterpret_cast<ce::list<sink>::iterator*>(&dwCookie);

    ce::gate<ce::critical_section> _gate(m_csListSinks);

    // validate the cookie by locating it among the live sinks instead of dereferencing it blindly
    ce::list<sink>::iterator itCur = m_listSinks.begin();
    ce::list<sink>::iterator itLast = m_listSinks.end();
    while(itCur != itLast && !(itCur == itTarget))
        ++itCur;

    if(itCur == itLast)
        return E_FAIL; // unknown or stale cookie

    assert(m_bInitialized);
    return itCur->subscribe(pszURL);
}
// dispatch_message
// Reads the hdr.nNumberOfBlocks body blocks that follow an event header in the
// message queue and assembles them according to hdr.nt. The result goes to the
// sink whose subscription handle equals hdr.hSubscription.
// The event is dropped (early return) when:
//   - shutdown is signaled,
//   - a queue read fails,
//   - a one-byte terminator shows the writer abandoned this event, or
//   - a string operation on the body failed.
void ConnectionPoint::dispatch_message(const event_msg_hdr& hdr)
{
    // Stays true while every string operation on this message has succeeded.
    // After a failure we still read all remaining blocks so the queue stays in
    // step (the next message must be the next event's header), but nothing is delivered.
    bool bMsgOk = true;
    DWORD dw, dwFlags;
    ce::wstring strEventMessage;
    ce::wstring strUSN;
    ce::wstring strLocation;
    ce::wstring strNLS;
    wchar_t pwchBuffer[MAX_MSG_SIZE / sizeof(wchar_t)];
    HANDLE pObjects[2] = {m_hEventShuttingDown, m_hMsgQueue};

    // Only property-change events accumulate a multi-block body in strEventMessage,
    // so only they need the up-front reservation (avoids reallocations while appending).
    // ALIVE/BYEBYE never use strEventMessage and must not be dropped because of it.
    if(NOTIFY_PROP_CHANGE == hdr.nt)
        bMsgOk = strEventMessage.reserve(hdr.nNumberOfBlocks * MAX_MSG_SIZE / sizeof(wchar_t));

    // read the body of event message
    for(int i = 0; i < hdr.nNumberOfBlocks; ++i)
    {
        dw = WaitForMultipleObjects(sizeof(pObjects)/sizeof(*pObjects), pObjects, false, INFINITE);
        if(WAIT_OBJECT_0 == dw)
            // shutdown event is signaled - return
            return;

        assert(dw == WAIT_OBJECT_0 + 1);

        if(!ReadMsgQueue(m_hMsgQueue, pwchBuffer, sizeof(pwchBuffer), &dw, INFINITE, &dwFlags))
        {
#ifdef DEBUG
            dw = GetLastError();
#endif
            assert(0);
            // should never get here
            // we don't know how to recover so just return
            return;
        }

        if(dw == 1)
        {
            // one byte message is a terminator message
            // writer could not finish writing message for this event, we have to abort
            // next message in the queue will be header for next event
            assert(*((char*)pwchBuffer) == 0);
            TraceTag(ttidError, "Detected terminate message in events msg queue; discarding current event.");
            return;
        }

        switch(hdr.nt)
        {
            case NOTIFY_PROP_CHANGE:
                // append to the event message body
                bMsgOk = bMsgOk && strEventMessage.append(pwchBuffer, dw/sizeof(*pwchBuffer));
                break;

            case NOTIFY_ALIVE:
                // three messages in alive body: USN, Location and NLS
                switch(i)
                {
                    case 0:
                        // USN
                        bMsgOk = bMsgOk && strUSN.assign(pwchBuffer, dw/sizeof(*pwchBuffer));
                        break;

                    case 1:
                        // LOCATION
                        bMsgOk = bMsgOk && strLocation.assign(pwchBuffer, dw/sizeof(*pwchBuffer));
                        break;

                    case 2:
                        // NLS
                        bMsgOk = bMsgOk && strNLS.assign(pwchBuffer, dw/sizeof(*pwchBuffer));
                        break;

                    default:
                        ASSERT(FALSE);
                }
                break;

            case NOTIFY_BYEBYE:
                // just one message in byebye body: USN
                ASSERT(i == 0);
                bMsgOk = bMsgOk && strUSN.assign(pwchBuffer, dw/sizeof(*pwchBuffer));
                break;

            default:
                ASSERT(FALSE);
                break;
        }
    }

    // if there was an error constructing the message then we don't want to attempt to process it
    if(!bMsgOk)
    {
        TraceTag(ttidError, "Failed reading event/alive/or byebye message into buffer.");
        return;
    }

    ce::gate<ce::critical_section> _gate(m_csListSinks);

    // find sink for the subscription
    for(ce::list<sink>::iterator it = m_listSinks.begin(), itEnd = m_listSinks.end(); it != itEnd; ++it)
        if(it->getSubscriptionHandle() == hdr.hSubscription)
        {
            switch(hdr.nt)
            {
                case NOTIFY_PROP_CHANGE:
                    it->event(strEventMessage, hdr.dwEventSEQ);
                    break;

                case NOTIFY_ALIVE:
                    it->alive(strUSN, strLocation, strNLS, hdr.dwLifeTime);
                    break;

                case NOTIFY_BYEBYE:
                    it->byebye(strUSN);
                    break;

                default:
                    ASSERT(FALSE);
                    break;
            }
            break;
        }
}
// listening_thread
// Thread procedure started by init(); p is the owning ConnectionPoint.
// Steps:
//   1. Load upnpctrl.dll and signal m_hStarted.
//   2. Create the SAX reader shared by sinks.
//   3. Read event headers from the message queue and dispatch them until
//      m_hEventShuttingDown is set.
// Exit code: ERROR_SUCCESS, or ERROR_OUTOFMEMORY when startup fails.
DWORD WINAPI ConnectionPoint::listening_thread(void* p)
{
    assert(ConnectionPoint::sink::m_pReader == NULL);

    DWORD errCode = ERROR_SUCCESS;
    HINSTANCE hLib;
    ConnectionPoint* pThis = reinterpret_cast<ConnectionPoint*>(p);
    HANDLE pObjects[2] = {pThis->m_hEventShuttingDown, pThis->m_hMsgQueue};

    // Keep the result: CoUninitialize may only balance a successful
    // CoInitializeEx (S_OK or S_FALSE), not a failed one.
    HRESULT hrCoInit = CoInitializeEx(NULL, COINIT_MULTITHREADED);

    // NOTE(review): presumably this code lives in upnpctrl.dll and this extra
    // reference keeps it mapped until FreeLibraryAndExitThread below - confirm.
    hLib = LoadLibrary(L"upnpctrl.dll");
    if(!hLib)
    {
        // note: exits without signaling m_hStarted
        errCode = ERROR_OUTOFMEMORY;
        goto Finish;
    }

    SetEvent(pThis->m_hStarted);

    srand(GetTickCount());

    ConnectionPoint::sink::m_pReader = new ce::SAXReader;
    if(!ConnectionPoint::sink::m_pReader)
    {
        errCode = ERROR_OUTOFMEMORY;
        goto Finish;
    }

    for(;;)
    {
        event_msg_hdr hdr;
        DWORD dw, dwFlags;

        dw = WaitForMultipleObjects(sizeof(pObjects)/sizeof(*pObjects), pObjects, false, INFINITE);
        if(WAIT_OBJECT_0 == dw)
            break;

        assert(dw == WAIT_OBJECT_0 + 1);

        // read message header
        if(ReadMsgQueue(pThis->m_hMsgQueue, &hdr, sizeof(hdr), &dw, INFINITE, &dwFlags))
        {
            assert(dw == sizeof(hdr));
            pThis->dispatch_message(hdr);
        }
        else
        {
            // should not get here
            TraceTag(ttidError, "ReadMsgQueue failed on events msg queue. Error code (%d).", GetLastError());
            assert(0);
        }
    }

Finish:
    if(ConnectionPoint::sink::m_pReader)
    {
        delete ConnectionPoint::sink::m_pReader;
        ConnectionPoint::sink::m_pReader = NULL;
    }

    if(SUCCEEDED(hrCoInit))
        CoUninitialize();

    // Release our upnpctrl.dll reference and exit in one call. Pass errCode so
    // the thread's exit code reports failures instead of always being 0.
    if(hLib)
        FreeLibraryAndExitThread(hLib, errCode);

    return errCode;
}
// register_notification
// Builds a unique "notifyXXXXXXXX-XXXXXXXX" query string (wide and narrow copies)
// from a GenerateUUID64() value. It then registers for BYEBYE, ALIVE and
// PROP_CHANGE notifications on m_strUSN, delivered to the message queue
// named pwszMsgQueueName.
// Returns S_OK, or HRESULT_FROM_WIN32(GetLastError()) if RegisterNotification fails.
HRESULT ConnectionPoint::sink::register_notification(LPCWSTR pwszMsgQueueName)
{
    LONGLONG uuid64 = GenerateUUID64();

    // The two 32-bit halves as laid out in memory: the first DWORD, then the next one.
    DWORD dwPart0 = (DWORD)uuid64;
    DWORD dwPart1 = *((DWORD *)&uuid64 + 1);

    swprintf(m_pwchQueryString, L"notify%.8x-%.8x", dwPart0, dwPart1);
    sprintf(m_pchQueryString, "notify%.8x-%.8x", dwPart0, dwPart1);

    m_hSubscription = RegisterNotification(NOTIFY_BYEBYE | NOTIFY_ALIVE | NOTIFY_PROP_CHANGE, m_strUSN, m_pwchQueryString, pwszMsgQueueName);

    return !m_hSubscription ? HRESULT_FROM_WIN32(GetLastError()) : S_OK;
}
// deregister_notification
void ConnectionPoint::sink::deregister_notification()
{
// deregister notification
DeregisterNotification(m_hSubscription);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -