⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 tcptransport.cpp

📁 Windows CE 6.0 Server 源码
💻 CPP
📖 第 1 页 / 共 3 页
字号:
//
// Copyright (c) Microsoft Corporation.  All rights reserved.
//
//
// Use of this source code is subject to the terms of the Microsoft shared
// source or premium shared source license agreement under which you licensed
// this source code. If you did not accept the terms of the license agreement,
// you are not authorized to use this source code. For the terms of the license,
// please see the license agreement between you and Microsoft or, if applicable,
// see the SOURCE.RTF on your install media or the root of your tools installation.
// THE SOURCE CODE IS PROVIDED "AS IS", WITH NO WARRANTIES.
//
#include "SMB_Globals.h"
#include "TCPTransport.h"
#include "Cracker.h"
#include "CriticalSection.h"
#include "Utils.h"

using namespace TCP_TRANSPORT;

//
// Definitions of the TCP transport globals (namespace TCP_TRANSPORT).
//
// Holders for the live TCP connections.  HaltTCPIncomingConnections walks this
// list under g_csLockTCPTransportGlobals; StopTCPTransport drains it.
ce::list<CONNECTION_HOLDER *, TCP_CONNECTION_HLDR_ALLOC > TCP_TRANSPORT::g_ConnectionList;
// One LISTEN_NODE per listen thread (pushed by StartTCPListenThread, looked up
// by LANA in TerminateTCPListenThread).
ce::list<LISTEN_NODE, TCP_LISTEN_NODE_ALLOC > TCP_TRANSPORT::g_SocketListenList;

// NOTE(review): presumably the allocator behind TCP_CONNECTION_HLDR_ALLOC -- confirm in TCPTransport.h
ce::fixed_block_allocator<10> TCP_TRANSPORT::g_ConnectionHolderAllocator;
// Lock for the globals in this file.  Initialized in StartTCPTransport and
// deleted in StopTCPTransport (or when StartTCPTransport fails).
CRITICAL_SECTION              TCP_TRANSPORT::g_csLockTCPTransportGlobals;
// TRUE while the transport is not running; the listen thread's accept loop
// runs only while this is FALSE.
BOOL                          TCP_TRANSPORT::g_fStopped = TRUE;
// Port every listen socket binds to.
const USHORT                  TCP_TRANSPORT::g_usTCPListenPort = 445;
// Set to 0xFFFFFFFF (the older value 15 is kept in the trailing comment).
// NOTE(review): presumably means "no timeout" -- confirm at the point of use.
const UINT                    TCP_TRANSPORT::g_uiTCPTimeoutInSeconds = 0xFFFFFFFF;//15;
// ID generator for connections; StopTCPTransport asserts no IDs remain outstanding.
UniqueID                      TCP_TRANSPORT::g_ConnectionID;


// Debug-build-only counter (IFDBG); presumably tracks live sockets -- not used in this chunk.
IFDBG(LONG                    TCP_TRANSPORT::g_lAliveSockets = 0); 

//
// Forward declares -- reference counting on a CONNECTION_HOLDER.  Their
// definitions are not visible in this part of the file.
VOID DecrementConnectionCounter(CONNECTION_HOLDER *pMyConnection);
VOID IncrementConnectionCounter(CONNECTION_HOLDER *pMyConnection);

//
//  Start the TCP transport.  This initializes the globals shared by the
//  listen and connection threads (before any of those threads are spun) and
//  starts Winsock 2.2.
//
//  Returns S_OK on success.  On failure it returns E_UNEXPECTED, deletes the
//  global critical section again, and leaves g_fStopped == TRUE.  As a result,
//  a later HaltTCPIncomingConnections / StopTCPTransport call takes its early
//  "already stopped" return instead of using the deleted critical section.
HRESULT StartTCPTransport() 
{
    TRACEMSG(ZONE_INIT, (TEXT("SMBSRV:Starting TCP transport")));
    HRESULT hr = E_FAIL;
    WORD wVersionRequested = MAKEWORD( 2, 2 );
    WSADATA wsaData;
    int iWSAError;

    //
    // Initialize globals
    ASSERT(0 == g_SocketListenList.size());

    InitializeCriticalSection(&TCP_TRANSPORT::g_csLockTCPTransportGlobals);

    //
    // WSAStartup returns its error code directly.  WSAGetLastError must not be
    // used when WSAStartup fails.
    iWSAError = WSAStartup( wVersionRequested, &wsaData );
    if (0 != iWSAError) {
        TRACEMSG(ZONE_INIT, (TEXT("SMBSRV: error with WSAStartup: %d"), iWSAError));
        hr = E_UNEXPECTED;
        goto Done;
    }

    //
    // Only mark the transport as running once everything is initialized.
    g_fStopped = FALSE;
    hr = S_OK;
    Done:
        if(FAILED(hr)) {
            g_fStopped = TRUE;
            DeleteCriticalSection(&TCP_TRANSPORT::g_csLockTCPTransportGlobals);
        }
        return hr;
}

//
// Start a listen thread for uiIPAddress on the given LANA.
//
// A new LISTEN_NODE is pushed onto g_SocketListenList under the transport lock.
// The thread is created suspended with that node as its argument, and resumed
// only after the node's fields are filled in.
//
// If CreateThread fails, the node is popped again.  Otherwise a node with no
// thread behind it (and with h/s never assigned) would stay in the list.
// HaltTCPIncomingConnections loops until that list is empty and calls
// TerminateTCPListenThread on such nodes.
//
// Returns S_OK, E_OUTOFMEMORY if the node could not be added, or E_FAIL if
// the thread could not be created.
HRESULT StartTCPListenThread(UINT uiIPAddress, BYTE LANA)
{
    HRESULT hr = E_FAIL;
    HANDLE h;
    CCritSection csLock(&TCP_TRANSPORT::g_csLockTCPTransportGlobals);
    csLock.Lock();
    LISTEN_NODE newNode;
    LISTEN_NODE *pNode = NULL;

    ASSERT(FALSE == g_fStopped);

    if(!g_SocketListenList.push_front(newNode)) {
        hr = E_OUTOFMEMORY;
        goto Done;
    }
    pNode = &(g_SocketListenList.front());

    if(NULL == (h = CreateThread(NULL, 0, SMBSRV_TCPListenThread, (LPVOID)pNode, CREATE_SUSPENDED, NULL))) {
       TRACEMSG(ZONE_INIT, (L"SMBSRV: CreateThread failed starting TCP Listen:%d", GetLastError()));
       ASSERT(FALSE);
       //
       // The lock is still held and the node was pushed to the front, so
       // front() is our node.  Remove it, since no thread will ever own it.
       g_SocketListenList.pop_front();
       pNode = NULL;
       goto Done;
    }
    pNode->LANA = LANA;
    pNode->s = INVALID_SOCKET;
    pNode->h = h;
    pNode->uiIPAddress = uiIPAddress;

    hr = S_OK;
    Done:

        if(SUCCEEDED(hr)) {
            ResumeThread(pNode->h);
        }
        return hr;
}


//
// Stop the listen thread(s) registered for LANA.
//
// Under the transport lock, every LISTEN_NODE whose LANA matches has its
// socket closed, and the handle of the last match is remembered.  Outside the
// lock, the function waits for that thread to exit and then closes its handle.
//
// Returns S_OK (including when no node matched), or E_UNEXPECTED if the wait
// failed.
HRESULT TerminateTCPListenThread(BYTE LANA) {
    CCritSection csLock(&TCP_TRANSPORT::g_csLockTCPTransportGlobals);
    HANDLE hListenThread = NULL;

    csLock.Lock();
        for(ce::list<LISTEN_NODE, TCP_LISTEN_NODE_ALLOC >::iterator itNode = g_SocketListenList.begin();
            itNode != g_SocketListenList.end();
            ++itNode) {
            if(LANA != itNode->LANA) {
                continue;
            }
            TRACEMSG(ZONE_DETAIL, (TEXT("SMBSRV: cleaning up socket [0x%08x] and handle [0x%08x] for LANA [0x%x]"), itNode->s, itNode->h, LANA));
            hListenThread = itNode->h;
            closesocket(itNode->s);
        }
    csLock.UnLock();

    //
    // Nothing registered for this LANA -- nothing to wait for.
    if(NULL == hListenThread) {
        return S_OK;
    }

    if(WAIT_FAILED == WaitForSingleObject(hListenThread, INFINITE)) {
        TRACEMSG(ZONE_ERROR, (TEXT("SMBSRV: error with handle [0x%08x] waiting for listen thread on LANA [0x%x] to die!!"), hListenThread, LANA));
        return E_UNEXPECTED;
    }

    CloseHandle(hListenThread);
    return S_OK;
}

//
// Stop accepting TCP connections and stop every existing connection thread.
//
// Steps:
//   1. Sets g_fStopped.
//   2. Repeatedly calls TerminateTCPListenThread for the front node of
//      g_SocketListenList until that list is empty.
//   3. Closes every connection socket, taking an extra reference on each
//      connection (PHASE1).
//   4. Waits for the connection threads to exit (PHASE2).
//
// The extra references are dropped later by StopTCPTransport.  On exit,
// g_fStopped is reset to FALSE so that StopTCPTransport performs that cleanup.
//
// Returns S_OK unless the last TerminateTCPListenThread call or a wait on a
// connection thread failed (E_UNEXPECTED).
HRESULT  HaltTCPIncomingConnections(void)
{
    HRESULT hr = S_OK;
    TRACEMSG(ZONE_INIT, (L"Halting Incoming TCP connections"));

    CCritSection csLock(&TCP_TRANSPORT::g_csLockTCPTransportGlobals);
    ce::list<CONNECTION_HOLDER *, TCP_CONNECTION_HLDR_ALLOC>::iterator itConn;

    if(TRUE == g_fStopped) {
        ASSERT(0 == g_SocketListenList.size());
        return S_OK;
    }

    //
    // Render all threads worthless
    csLock.Lock();
        g_fStopped = TRUE;  
        while(g_SocketListenList.size()) {
            //
            // Every other access to g_SocketListenList in this file holds the
            // lock.  Copy the LANA out *before* dropping the lock instead of
            // calling front() again after UnLock().
            BYTE bLANA = g_SocketListenList.front().LANA;
            TRACEMSG(ZONE_TCPIP, (L"SMBSRV-LSTNTHREAD: Closing listen socket: 0x%x", g_SocketListenList.front().s));
            csLock.UnLock();
            hr = TerminateTCPListenThread(bLANA);
            ASSERT(SUCCEEDED(hr));
            csLock.Lock();
        }



        //
        // Three phases
        //
        // 1.  kill all sockets and inc ref cnt so items dont go away from under us
        //     NOTE: this *PREVENTS* DecrementConnectionCounter from deleting their
        //     memory (AND entering the critical section)... so it also prevents
        //     deadlock -- ie, the list length is fixed
        // 2.  wait for all threads to stop
        //     NOTE(review): despite the original "(not under CS)" intent, the
        //     lock is still held during these waits (UnLock() comes after the
        //     loop).  Confirm that connection threads never need
        //     g_csLockTCPTransportGlobals on their way out.
        // 3.  dec all ref cnts (under CS) -- done in StopTCPTransport (after the cracker
        //     has stopped)

        //PHASE1
        ASSERT(TRUE == csLock.IsLocked());
        for(itConn = g_ConnectionList.begin(); itConn != g_ConnectionList.end(); ++itConn) {
            CONNECTION_HOLDER *pToHalt = (*itConn);
            ASSERT(NULL != pToHalt);

            if(pToHalt) {
                //
                // get a refcnt to each connection in the list
                // we use this to guarantee our list CANT shrink.
                IncrementConnectionCounter(pToHalt);
                closesocket(pToHalt->sock);
                pToHalt->sock = INVALID_SOCKET;
            }
        }

        //PHASE2
        for(itConn = g_ConnectionList.begin(); itConn != g_ConnectionList.end(); ++itConn) {
            CONNECTION_HOLDER *pToHalt = (*itConn);

            if(NULL == pToHalt || WAIT_FAILED == WaitForSingleObject(pToHalt->hHandle, INFINITE)) {
                TRACEMSG(ZONE_ERROR, (TEXT("SMBSRV: MAJOR error waiting for connection thread to exit!!")));
                ASSERT(FALSE);
                hr = E_UNEXPECTED;
                goto Done;
            }
        }
    csLock.UnLock();

    Done:
        g_fStopped = FALSE;  //mark us as not stopped... because all sockets (listen and recving)
                            //  have stopped we do this so StopTCPTransport can finish cleaning up
        ASSERT(0 == g_SocketListenList.size());
        return hr;
}

//
// Final teardown of the TCP transport.
//
// Expected to run after the listen threads are gone and the cracker has
// stopped; both conditions are asserted.  It then:
//   - drops the per-connection reference taken in HaltTCPIncomingConnections
//     (PHASE3), emptying g_ConnectionList;
//   - deletes the global critical section;
//   - calls WSACleanup to balance the WSAStartup made by StartTCPTransport.
//
// Returns S_OK, including when the transport is already stopped.
HRESULT StopTCPTransport(void)
{
    TRACEMSG(ZONE_INIT, (TEXT("SMBSRV:Stopping TCP transport")));

    CCritSection csLock(&TCP_TRANSPORT::g_csLockTCPTransportGlobals);

    if(TRUE == g_fStopped) {
        return S_OK;
    } else {
        g_fStopped = TRUE;
    }

    csLock.Lock();
    ASSERT(0 == g_SocketListenList.size());

    //PHASE3
    //now that both the listen and cracker threads are dead, we know that no NEW sockets (or threads)
    //  can be created... so loop through them all 
    ASSERT(FALSE == Cracker::g_fIsRunning);
    while(0 != g_ConnectionList.size()) {
        CONNECTION_HOLDER *pTemp = g_ConnectionList.front();
        if(NULL == pTemp) {
            ASSERT(FALSE); //internal ERROR!!! we should NEVER have null!! EVER!
            g_ConnectionList.pop_front();
            TRACEMSG(ZONE_TCPIP, (L"SMBSRV: TCP CONNECTIONLIST: %d", g_ConnectionList.size()));
            continue;
        }

        //
        // Do all cleanup for the holder here -- the decrement is expected to
        // remove pTemp from g_ConnectionList (checked in debug builds).
        IFDBG(UINT uiPrevLen = g_ConnectionList.size());
        DecrementConnectionCounter(pTemp);
        IFDBG(ASSERT(uiPrevLen - 1 == g_ConnectionList.size()));
    }
    ASSERT(0 == g_ConnectionList.size()); 
    ASSERT(0 == TCP_TRANSPORT::g_ConnectionID.NumIDSOutstanding());
    csLock.UnLock();
    DeleteCriticalSection(&TCP_TRANSPORT::g_csLockTCPTransportGlobals);

    //
    // Every successful WSAStartup must be matched by a WSACleanup.
    // StartTCPTransport made one, so release it here.  A failure is only
    // logged, since the transport is already torn down.
    if(0 != WSACleanup()) {
        TRACEMSG(ZONE_ERROR, (TEXT("SMBSRV: error with WSACleanup: %d"), WSAGetLastError()));
    }

    return S_OK;
}




DWORD
TCP_TRANSPORT::SMBSRV_TCPListenThread(LPVOID _myNode)
{
    HRESULT hr = E_FAIL; 
    SOCKADDR_IN my_addr;  
    SOCKET s = INVALID_SOCKET;
    LISTEN_NODE *pMyNode = (LISTEN_NODE*)_myNode;

    UINT uiIPAddr = pMyNode->uiIPAddress;

    CCritSection csLock(&TCP_TRANSPORT::g_csLockTCPTransportGlobals);
    csLock.Lock();

    if(TRUE == g_fStopped) {
        hr = S_OK;
        goto Done;
    }

    if ((s = socket(AF_INET, SOCK_STREAM, 0)) == -1) {
        TRACEMSG(ZONE_ERROR, (TEXT("SMBSRV-LSTNTHREAD: creating socket FAILED")));
        hr = E_UNEXPECTED;
        goto Done;
    }
    TRACEMSG(ZONE_TCPIP, (L"SMBSRV-LSTNTHREAD: Binding listen socket: 0x%x", s));

    my_addr.sin_family = AF_INET;
    my_addr.sin_port = htons(g_usTCPListenPort);
    my_addr.sin_addr.s_addr = uiIPAddr;

    if (SOCKET_ERROR == bind(s, (struct sockaddr *)&my_addr, sizeof(struct sockaddr))){
        TRACEMSG(ZONE_ERROR, (TEXT("SMBSRV-LSTNTHREAD: cant bind socket!!")));
        hr = E_UNEXPECTED;
        goto Done;
    }

    if (-1 == listen(s, 10)) {
        TRACEMSG(ZONE_ERROR, (TEXT("SMBSRV-LSTNTHREAD: cant listen!!")));
        hr = E_UNEXPECTED;
        goto Done;
    }

    //
    // Update the master list with our socket
    pMyNode->s = s;

    //
    // Now that we've added ourself to the global list, release our lock
    csLock.UnLock();

    //
    // Loop as long as we are not stopped
    while (FALSE == g_fStopped) 
    {
        SOCKADDR_IN their_addr;
        int sin_size = sizeof(struct sockaddr_in);
        SOCKET newSock = accept(s, (struct sockaddr *)&their_addr, &sin_size);


        if(INVALID_SOCKET == newSock) {
            TRACEMSG(ZONE_ERROR, (TEXT("SMBSRV-LSTNTHREAD: got an invalid socket on accept!!!")));
            goto DoneSocketInList;
        } else {
            /*BOOL fDisable = TRUE;
            if (ERROR_SUCCESS != setsockopt(newSock,IPPROTO_TCP,TCP_NODELAY,(const char*)&fDisable,sizeof(BOOL))) {
                TRACEMSG(ZONE_ERROR,(L"HTTPD: setsockopt(0x%08x,TCP_NODELAY) failed",newSock));
                ASSERT(FALSE);
            }*/

            CONNECTION_HOLDER *pNewConn = new CONNECTION_HOLDER();

            TRACEMSG(ZONE_TCPIP, (L"SMBSRV-LSTNTHREAD: got new TCP connection!"));

            if(NULL == pNewConn) {
                TRACEMSG(ZONE_ERROR, (TEXT("SMBSRV-LSTNTHREAD: OOM -- cant make CONNECTION_HOLDER!!!")));
                closesocket(newSock);
                continue;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -