// tcptransport.cpp
//
// Copyright (c) Microsoft Corporation. All rights reserved.
//
//
// Use of this source code is subject to the terms of the Microsoft shared
// source or premium shared source license agreement under which you licensed
// this source code. If you did not accept the terms of the license agreement,
// you are not authorized to use this source code. For the terms of the license,
// please see the license agreement between you and Microsoft or, if applicable,
// see the SOURCE.RTF on your install media or the root of your tools installation.
// THE SOURCE CODE IS PROVIDED "AS IS", WITH NO WARRANTIES.
//
#include "SMB_Globals.h"
#include "TCPTransport.h"
#include "Cracker.h"
#include "CriticalSection.h"
#include "Utils.h"
using namespace TCP_TRANSPORT;
//
// Live TCP connections (CONNECTION_HOLDER pointers).  Walked under
// g_csLockTCPTransportGlobals by HaltTCPIncomingConnections and drained by
// StopTCPTransport.
ce::list<CONNECTION_HOLDER *, TCP_CONNECTION_HLDR_ALLOC > TCP_TRANSPORT::g_ConnectionList;
//
// One LISTEN_NODE per listen thread (pushed by StartTCPListenThread); guarded
// by g_csLockTCPTransportGlobals.
ce::list<LISTEN_NODE, TCP_LISTEN_NODE_ALLOC > TCP_TRANSPORT::g_SocketListenList;
//
// Fixed-block allocator -- presumably backs TCP_CONNECTION_HLDR_ALLOC; confirm
// in TCPTransport.h.
ce::fixed_block_allocator<10> TCP_TRANSPORT::g_ConnectionHolderAllocator;
//
// Lock for the globals above.  Initialized in StartTCPTransport and deleted in
// StopTCPTransport.
CRITICAL_SECTION TCP_TRANSPORT::g_csLockTCPTransportGlobals;
//
// TRUE while the transport is not running (or is being halted/stopped); the
// listen thread loops only while this is FALSE.
BOOL TCP_TRANSPORT::g_fStopped = TRUE;
//
// TCP port the listen sockets bind to (445, the port for SMB directly over TCP).
const USHORT TCP_TRANSPORT::g_usTCPListenPort = 445;
//
// Connection timeout.  0xFFFFFFFF presumably means "never time out" -- the
// previous value (15) is kept in the trailing comment.  Not used in this view.
const UINT TCP_TRANSPORT::g_uiTCPTimeoutInSeconds = 0xFFFFFFFF;//15;
//
// Source of connection IDs; StopTCPTransport asserts none are outstanding.
UniqueID TCP_TRANSPORT::g_ConnectionID;
//
// Debug-only count of live sockets (compiled in only via IFDBG).
IFDBG(LONG TCP_TRANSPORT::g_lAliveSockets = 0);
//
// Forward declares (reference counting on connection holders; definitions are
// not visible in this section)
VOID DecrementConnectionCounter(CONNECTION_HOLDER *pMyConnection);
VOID IncrementConnectionCounter(CONNECTION_HOLDER *pMyConnection);
//
// Start the TCP transport -- initializes the module globals (the transport
// critical section and Winsock) before any listen/connection threads are
// spun up.
//
// Returns S_OK on success, or E_UNEXPECTED if WSAStartup fails.  On failure
// the critical section is deleted again and the transport stays marked as
// stopped, so a later StopTCPTransport() returns early instead of touching
// the deleted critical section.
HRESULT StartTCPTransport()
{
    TRACEMSG(ZONE_INIT, (TEXT("SMBSRV:Starting TCP transport")));
    HRESULT hr = E_FAIL;
    WORD wVersionRequested = MAKEWORD( 2, 2 );
    WSADATA wsaData;
    int iWSAError;

    //
    // Initialize globals
    ASSERT(0 == g_SocketListenList.size());
    InitializeCriticalSection(&TCP_TRANSPORT::g_csLockTCPTransportGlobals);

    //
    // WSAStartup reports its error through its return value; WSAGetLastError
    // must not be used to retrieve it.
    iWSAError = WSAStartup( wVersionRequested, &wsaData );
    if (0 != iWSAError) {
        TRACEMSG(ZONE_INIT, (TEXT("SMBSRV: error with WSAStartup: %d"), iWSAError));
        hr = E_UNEXPECTED;
        goto Done;
    }

    //
    // Only mark the transport as running once everything above succeeded
    g_fStopped = FALSE;
    hr = S_OK;
Done:
    if(FAILED(hr)) {
        DeleteCriticalSection(&TCP_TRANSPORT::g_csLockTCPTransportGlobals);
        g_fStopped = TRUE;
    }
    return hr;
}
//
// Create a listen thread for one local IP address / LANA pair.
//
// Under the transport lock a LISTEN_NODE is pushed onto g_SocketListenList,
// filled in, and handed to SMBSRV_TCPListenThread, which is created suspended
// and only resumed once the node is complete.  The thread creates and binds
// the listen socket itself (s starts out as INVALID_SOCKET).
//
// Returns S_OK on success, E_OUTOFMEMORY if the node cannot be queued, or
// E_FAIL if CreateThread fails.  On CreateThread failure the node is removed
// from the list again -- no thread exists to remove it later, and
// HaltTCPIncomingConnections() loops until the list is empty.
HRESULT StartTCPListenThread(UINT uiIPAddress, BYTE LANA)
{
    HRESULT hr = E_FAIL;
    HANDLE h = NULL;
    LISTEN_NODE newNode;
    LISTEN_NODE *pNode = NULL;
    CCritSection csLock(&TCP_TRANSPORT::g_csLockTCPTransportGlobals);
    csLock.Lock();
    ASSERT(FALSE == g_fStopped);
    if(!g_SocketListenList.push_front(newNode)) {
        hr = E_OUTOFMEMORY;
        goto Done;
    }
    pNode = &(g_SocketListenList.front());

    //
    // Fully describe the listener before the thread can ever see the node
    pNode->LANA = LANA;
    pNode->s = INVALID_SOCKET;
    pNode->h = NULL;
    pNode->uiIPAddress = uiIPAddress;

    if(NULL == (h = CreateThread(NULL, 0, SMBSRV_TCPListenThread, (LPVOID)pNode, CREATE_SUSPENDED, NULL))) {
        TRACEMSG(ZONE_INIT, (L"SMBSRV: CreateThread failed starting TCP Listen:%d", GetLastError()));
        ASSERT(FALSE);
        //
        // No thread owns this node, so take it back out of the list.  We
        // still hold the lock, so it is still the node at the front.
        g_SocketListenList.pop_front();
        pNode = NULL;
        goto Done;
    }
    pNode->h = h;
    hr = S_OK;
Done:
    if(SUCCEEDED(hr)) {
        ResumeThread(pNode->h);
    }
    return hr;
}
//
// Stop the listen thread(s) registered for the given LANA.
//
// While holding the transport lock, the listen socket of every node whose
// LANA matches is closed, and the thread handle of the last matching node is
// remembered.  The lock is then released and we block until that thread has
// exited, after which its handle is closed.  This function does not remove the
// node from g_SocketListenList.
//
// Returns S_OK (including when no node matches), or E_UNEXPECTED if waiting on
// the thread handle fails.
HRESULT TerminateTCPListenThread(BYTE LANA) {
    HANDLE hListenThread = NULL;
    CCritSection csLock(&TCP_TRANSPORT::g_csLockTCPTransportGlobals);

    csLock.Lock();
    ce::list<LISTEN_NODE, TCP_LISTEN_NODE_ALLOC >::iterator itNode = g_SocketListenList.begin();
    while(itNode != g_SocketListenList.end()) {
        if(LANA == itNode->LANA) {
            TRACEMSG(ZONE_DETAIL, (TEXT("SMBSRV: cleaning up socket [0x%08x] and handle [0x%08x] for LANA [0x%x]"), itNode->s, itNode->h, LANA));
            hListenThread = itNode->h;
            closesocket(itNode->s);
        }
        ++itNode;
    }
    csLock.UnLock();

    //
    // Nothing registered for this LANA -- nothing to wait for
    if(NULL == hListenThread) {
        return S_OK;
    }

    //
    // Wait outside the lock so the exiting thread can take it
    if(WAIT_FAILED == WaitForSingleObject(hListenThread, INFINITE)) {
        TRACEMSG(ZONE_ERROR, (TEXT("SMBSRV: error with handle [0x%08x] waiting for listen thread on LANA [0x%x] to die!!"), hListenThread, LANA));
        return E_UNEXPECTED;
    }
    CloseHandle(hListenThread);
    return S_OK;
}
//
// Stop accepting new TCP connections and stop every live connection thread.
//
// Marks the transport stopped, terminates every listen thread, then closes
// every connection socket and waits for the connection threads to exit.  The
// final release of the connection holders is left to StopTCPTransport()
// (after the cracker has stopped).  On return g_fStopped is FALSE again so
// that StopTCPTransport() performs its cleanup.
//
// Returns S_OK; otherwise the result of the last TerminateTCPListenThread()
// call, or E_UNEXPECTED if a connection thread could not be waited on.
HRESULT HaltTCPIncomingConnections(void)
{
    HRESULT hr = S_OK;
    TRACEMSG(ZONE_INIT, (L"Halting Incoming TCP connections"));
    CCritSection csLock(&TCP_TRANSPORT::g_csLockTCPTransportGlobals);
    ce::list<CONNECTION_HOLDER *, TCP_CONNECTION_HLDR_ALLOC>::iterator itConn;
    //
    // Checked before taking the lock: after StopTCPTransport() the critical
    // section has been deleted and must not be entered.
    if(TRUE == g_fStopped) {
        ASSERT(0 == g_SocketListenList.size());
        return S_OK;
    }
    //
    // Render all threads worthless
    csLock.Lock();
    g_fStopped = TRUE;
    while(g_SocketListenList.size()) {
        //
        // Read the LANA while we still hold the lock.  Once the lock is
        // released the list may change underneath us (listen threads
        // presumably remove their own node on exit -- TerminateTCPListenThread
        // does not), so touching front() unlocked would be a data race.
        BYTE LANA = g_SocketListenList.front().LANA;
        TRACEMSG(ZONE_TCPIP, (L"SMBSRV-LSTNTHREAD: Closing listen socket: 0x%x", g_SocketListenList.front().s));
        csLock.UnLock();
        hr = TerminateTCPListenThread(LANA);
        ASSERT(SUCCEEDED(hr));
        csLock.Lock();
    }
    //
    // Three phases
    //
    // 1. kill all sockets and inc ref cnt so items dont go away from under us
    //    NOTE: this *PREVENTS* DecrementConnectionCounter from deleting their
    //    memory (AND entering the critical section)... so it also prevents
    //    deadlock -- ie, the list length is fixed
    // 2. wait for all threads to stop
    //    NOTE(review): the original design says this happens "not under CS",
    //    but the lock is still held during the waits below.  This relies on
    //    connection threads never needing g_csLockTCPTransportGlobals on their
    //    exit path -- confirm.
    // 3. dec all ref cnts (under CS) -- done in StopTCPTransport (after the cracker
    //    has stopped)
    //PHASE1
    ASSERT(TRUE == csLock.IsLocked());
    for(itConn = g_ConnectionList.begin(); itConn != g_ConnectionList.end(); ++itConn) {
        CONNECTION_HOLDER *pToHalt = (*itConn);
        ASSERT(NULL != pToHalt);
        if(pToHalt) {
            //
            // get a refcnt to each connection in the list
            // we use this to guarantee our list CANT shrink.
            IncrementConnectionCounter(pToHalt);
            closesocket(pToHalt->sock);
            pToHalt->sock = INVALID_SOCKET;
        }
    }
    //PHASE2
    for(itConn = g_ConnectionList.begin(); itConn != g_ConnectionList.end(); ++itConn) {
        CONNECTION_HOLDER *pToHalt = (*itConn);
        if(NULL == pToHalt || WAIT_FAILED == WaitForSingleObject(pToHalt->hHandle, INFINITE)) {
            TRACEMSG(ZONE_ERROR, (TEXT("SMBSRV: MAJOR error waiting for connection thread to exit!!")));
            ASSERT(FALSE);
            hr = E_UNEXPECTED;
            goto Done;
        }
    }
    csLock.UnLock();
Done:
    g_fStopped = FALSE; //mark us as not stopped... because all sockets (listen and recving)
    // have stopped we do this so StopTCPTransport can finish cleaning up
    ASSERT(0 == g_SocketListenList.size());
    return hr;
}
//
// Final teardown of the TCP transport.
//
// Intended to run after HaltTCPIncomingConnections() has stopped the listen
// and connection threads and after the cracker has stopped (asserted below).
// Under the transport lock it drops the reference held on every remaining
// CONNECTION_HOLDER.  Each DecrementConnectionCounter() call is expected to
// remove the holder from g_ConnectionList (checked by the debug assert).  It
// then releases Winsock and deletes the transport critical section.  It does
// nothing if the transport is already marked stopped.
//
// Always returns S_OK.
HRESULT StopTCPTransport(void)
{
    TRACEMSG(ZONE_INIT, (TEXT("SMBSRV:Stopping TCP transport")));
    CCritSection csLock(&TCP_TRANSPORT::g_csLockTCPTransportGlobals);
    if(TRUE == g_fStopped) {
        return S_OK;
    }
    g_fStopped = TRUE;
    csLock.Lock();
    ASSERT(0 == g_SocketListenList.size());
    //PHASE3
    //now that both the listen and cracker threads are dead, we know that no NEW sockets (or threads)
    // can be created... so loop through them all
    ASSERT(FALSE == Cracker::g_fIsRunning);
    while(0 != g_ConnectionList.size()) {
        CONNECTION_HOLDER *pTemp = g_ConnectionList.front();
        if(NULL == pTemp) {
            ASSERT(FALSE); //internal ERROR!!! we should NEVER have null!! EVER!
            g_ConnectionList.pop_front();
            TRACEMSG(ZONE_TCPIP, (L"SMBSRV: TCP CONNECTIONLIST: %d", g_ConnectionList.size()));
            continue;
        }
        //
        // Do all cleanup for the holder here
        IFDBG(UINT uiPrevLen = g_ConnectionList.size());
        DecrementConnectionCounter(pTemp);
        IFDBG(ASSERT(uiPrevLen - 1 == g_ConnectionList.size()));
    }
    ASSERT(0 == g_ConnectionList.size());
    ASSERT(0 == TCP_TRANSPORT::g_ConnectionID.NumIDSOutstanding());
    csLock.UnLock();
    //
    // Balance the WSAStartup() call made in StartTCPTransport().  Every
    // successful WSAStartup must be matched by one WSACleanup.  On failure
    // the error is reported through WSAGetLastError().
    if(0 != WSACleanup()) {
        TRACEMSG(ZONE_ERROR, (TEXT("SMBSRV: error with WSACleanup: %d"), WSAGetLastError()));
    }
    DeleteCriticalSection(&TCP_TRANSPORT::g_csLockTCPTransportGlobals);
    return S_OK;
}
DWORD
TCP_TRANSPORT::SMBSRV_TCPListenThread(LPVOID _myNode)
{
HRESULT hr = E_FAIL;
SOCKADDR_IN my_addr;
SOCKET s = INVALID_SOCKET;
LISTEN_NODE *pMyNode = (LISTEN_NODE*)_myNode;
UINT uiIPAddr = pMyNode->uiIPAddress;
CCritSection csLock(&TCP_TRANSPORT::g_csLockTCPTransportGlobals);
csLock.Lock();
if(TRUE == g_fStopped) {
hr = S_OK;
goto Done;
}
if ((s = socket(AF_INET, SOCK_STREAM, 0)) == -1) {
TRACEMSG(ZONE_ERROR, (TEXT("SMBSRV-LSTNTHREAD: creating socket FAILED")));
hr = E_UNEXPECTED;
goto Done;
}
TRACEMSG(ZONE_TCPIP, (L"SMBSRV-LSTNTHREAD: Binding listen socket: 0x%x", s));
my_addr.sin_family = AF_INET;
my_addr.sin_port = htons(g_usTCPListenPort);
my_addr.sin_addr.s_addr = uiIPAddr;
if (SOCKET_ERROR == bind(s, (struct sockaddr *)&my_addr, sizeof(struct sockaddr))){
TRACEMSG(ZONE_ERROR, (TEXT("SMBSRV-LSTNTHREAD: cant bind socket!!")));
hr = E_UNEXPECTED;
goto Done;
}
if (-1 == listen(s, 10)) {
TRACEMSG(ZONE_ERROR, (TEXT("SMBSRV-LSTNTHREAD: cant listen!!")));
hr = E_UNEXPECTED;
goto Done;
}
//
// Update the master list with our socket
pMyNode->s = s;
//
// Now that we've added ourself to the global list, release our lock
csLock.UnLock();
//
// Loop as long as we are not stopped
while (FALSE == g_fStopped)
{
SOCKADDR_IN their_addr;
int sin_size = sizeof(struct sockaddr_in);
SOCKET newSock = accept(s, (struct sockaddr *)&their_addr, &sin_size);
if(INVALID_SOCKET == newSock) {
TRACEMSG(ZONE_ERROR, (TEXT("SMBSRV-LSTNTHREAD: got an invalid socket on accept!!!")));
goto DoneSocketInList;
} else {
/*BOOL fDisable = TRUE;
if (ERROR_SUCCESS != setsockopt(newSock,IPPROTO_TCP,TCP_NODELAY,(const char*)&fDisable,sizeof(BOOL))) {
TRACEMSG(ZONE_ERROR,(L"HTTPD: setsockopt(0x%08x,TCP_NODELAY) failed",newSock));
ASSERT(FALSE);
}*/
CONNECTION_HOLDER *pNewConn = new CONNECTION_HOLDER();
TRACEMSG(ZONE_TCPIP, (L"SMBSRV-LSTNTHREAD: got new TCP connection!"));
if(NULL == pNewConn) {
TRACEMSG(ZONE_ERROR, (TEXT("SMBSRV-LSTNTHREAD: OOM -- cant make CONNECTION_HOLDER!!!")));
closesocket(newSock);
continue;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -