// main.cpp
// Module Name: main.cpp
//
// Description:
//
// This sample illustrates how to develop a layered service provider that is
// capable of counting all bytes transmitted through an IP socket. The application
// reports when sockets are created and reports how many bytes were sent and
// received when a socket closes. The results are reported through the
// OutputDebugString API, so you can view them in a debugger such as cdb.exe
// or with the dbmon.exe monitoring tool.
//
// This file contains the 30 SPI functions that a service provider is required
// to implement. It also contains the two functions that the DLL module must
// export: DllMain and WSPStartup.
//
//
// Compile:
//
// This project is managed through the LSP.DSW project file.
//
// Execute:
//
// This project produces a DLL named lsp.dll. This DLL should be copied to the
// %SystemRoot%\System32 directory. Once the file is in place, run the
// application instlsp.exe to insert this provider into the Winsock 2 catalog
// of service providers.
//
#include "provider.h"
// ---------------------------------------------------------------------------
// Provider-wide state
// ---------------------------------------------------------------------------

// Upcall table used to create, query and close layered socket handles
// (lpWPUCreateSocketHandle / lpWPUQuerySocketHandleContext /
// lpWPUCloseSocketHandle). Presumably filled in by WSPStartup -- not visible here.
WSPUPCALLTABLE MainUpCallTable;
// NOTE(review): not referenced in this part of the file; presumably the
// catalog entry id of the layered protocol entry -- confirm.
DWORD gLayerCatId = 0;
// Catalog entry id passed to WPUCreateSocketHandle when WSPAccept creates a
// new layered socket handle.
DWORD gChainId = 0;
// Number of outstanding startups. WSPCleanup rejects calls while it is 0 and
// unloads hProvider when it drops back to 0.
DWORD gEntryCount = 0;
// Initialized in DllMain on process attach; WSPCleanup holds it while
// updating gEntryCount.
CRITICAL_SECTION gCriticalSection;
// NOTE(review): not referenced in this part of the file.
LPWSPDATA gWSPData = NULL;
// Dispatch table of the next (lower) provider; the WSP* entry points in this
// file forward their calls through it.
WSPPROC_TABLE NextProcTable;
// NOTE(review): not referenced in this part of the file.
LPWSPPROC_TABLE gProcTable = NULL;
// Protocol info of the base provider(s); WSPAddressToString hands
// gBaseInfo[0] to the next provider.
LPWSAPROTOCOL_INFOW gBaseInfo = NULL;
// This DLL's own module handle, recorded in DllMain on process attach.
HINSTANCE HDllInstance = NULL;
// Module handle released with FreeLibrary by WSPCleanup when gEntryCount
// reaches 0 (presumably the next provider's DLL, loaded by WSPStartup).
HINSTANCE hProvider = NULL;
INT gLayerCount=0; // Number of base providers we're layered over
// NOTE(review): message scratch buffer; not referenced in this part of the file.
static TCHAR Msg[512];
// DllMain
//
// DLL entry point. On process attach it records this module's handle and
// initializes the locks used by the provider. On a FreeLibrary-driven process
// detach it releases the global critical section again.
//
// Parameters:
//   hinstDll    - handle of this DLL module.
//   dwReason    - DLL_PROCESS_ATTACH/DETACH or DLL_THREAD_ATTACH/DETACH.
//   lpvReserved - on DLL_PROCESS_DETACH: NULL when the DLL is unloaded with
//                 FreeLibrary, non-NULL when the process is terminating.
//
// Returns TRUE for every notification.
BOOL WINAPI DllMain(IN HINSTANCE hinstDll, IN DWORD dwReason, LPVOID lpvReserved)
{
    switch (dwReason)
    {
    case DLL_PROCESS_ATTACH:
        HDllInstance = hinstDll;
        InitializeCriticalSection(&gCriticalSection);
        // InitAsyncSelectCS();   (disabled in the original sample)
        InitOverlappedCS();
        break;

    case DLL_PROCESS_DETACH:
        // Only tear down when unloaded via FreeLibrary. During process
        // termination other threads have already been stopped and may have
        // left the lock owned, so touching it is unsafe and unnecessary.
        // NOTE(review): whatever InitOverlappedCS() set up is not released
        // here; call its counterpart if one exists.
        if (lpvReserved == NULL)
            DeleteCriticalSection(&gCriticalSection);
        break;

    case DLL_THREAD_ATTACH:
    case DLL_THREAD_DETACH:
    default:
        break;
    }
    return TRUE;
}
// WSPAccept
//
// Accepts a connection on the layered listening socket s. The call is
// forwarded to the next provider on the underlying provider socket. The
// socket it returns is then wrapped in a freshly allocated SOCK_INFO context
// and a new layered socket handle, which is what the caller receives.
//
// Returns the new layered socket, or INVALID_SOCKET with *lpErrno set.
// If a step fails after the lower provider has accepted, the lower socket is
// closed and the context is freed, so neither leaks.
SOCKET WSPAPI WSPAccept (
    SOCKET s,
    struct sockaddr FAR * addr,
    LPINT addrlen,
    LPCONDITIONPROC lpfnCondition,
    DWORD dwCallbackData,
    LPINT lpErrno)
{
    SOCKET NewProviderSocket;
    SOCKET NewSocket;
    SOCK_INFO *NewSocketContext;
    SOCK_INFO *SocketContext;
    INT CleanupErr;   // scratch error slot so cleanup never overwrites *lpErrno

    if (MainUpCallTable.lpWPUQuerySocketHandleContext(s, (LPDWORD) &SocketContext, lpErrno) == SOCKET_ERROR)
        return INVALID_SOCKET;

    NewProviderSocket = NextProcTable.lpWSPAccept(SocketContext->ProviderSocket, addr, addrlen,
        lpfnCondition, dwCallbackData, lpErrno);
    if (NewProviderSocket == INVALID_SOCKET)
        return INVALID_SOCKET;

    // GPTR returns zero-initialized fixed memory; fields are still set
    // explicitly for clarity.
    NewSocketContext = (SOCK_INFO *) GlobalAlloc(GPTR, sizeof(SOCK_INFO));
    if (NewSocketContext == NULL)
    {
        // Do not leak the connection the lower provider already accepted.
        NextProcTable.lpWSPCloseSocket(NewProviderSocket, &CleanupErr);
        *lpErrno = WSAENOBUFS;
        return INVALID_SOCKET;
    }
    NewSocketContext->ProviderSocket = NewProviderSocket;
    NewSocketContext->bClosing = FALSE;
    NewSocketContext->dwOutstandingAsync = 0;
    NewSocketContext->BytesRecv = 0;
    NewSocketContext->BytesSent = 0;

    // NOTE(review): the context pointer travels through a DWORD, which only
    // round-trips on 32-bit builds (matches the SDK signature this file uses).
    NewSocket = MainUpCallTable.lpWPUCreateSocketHandle(gChainId, (DWORD) NewSocketContext, lpErrno);
    if (NewSocket == INVALID_SOCKET)
    {
        // *lpErrno was set by WPUCreateSocketHandle; just undo our own work.
        NextProcTable.lpWSPCloseSocket(NewProviderSocket, &CleanupErr);
        GlobalFree(NewSocketContext);
        return INVALID_SOCKET;
    }

    DuplicateAsyncSocket(SocketContext->ProviderSocket, NewProviderSocket, NewSocket);

    {
        TCHAR buffer[128];
        wsprintf(buffer, L"Creating socket %d\n", NewSocket);
        OutputDebugString(buffer);
    }
    return NewSocket;
}
// WSPAddressToString
//
// Converts a socket address to its string form by delegating to the next
// provider. The caller-supplied lpProtocolInfo is not forwarded. The first
// base-provider entry (gBaseInfo[0]) is passed down instead, presumably
// because lpProtocolInfo describes this layered entry rather than the base.
int WSPAPI WSPAddressToString(
    LPSOCKADDR lpsaAddress,
    DWORD dwAddressLength,
    LPWSAPROTOCOL_INFOW lpProtocolInfo,
    LPWSTR lpszAddressString,
    LPDWORD lpdwAddressStringLength,
    LPINT lpErrno)
{
    (void) lpProtocolInfo;                         // intentionally unused
    LPWSAPROTOCOL_INFOW BaseProtocol = gBaseInfo;  // same as &gBaseInfo[0]

    return NextProcTable.lpWSPAddressToString(lpsaAddress,
                                              dwAddressLength,
                                              BaseProtocol,
                                              lpszAddressString,
                                              lpdwAddressStringLength,
                                              lpErrno);
}
// WSPAsyncSelect
//
// Requests window-message notification of network events on layered socket s.
// SetWorkerWindow is given the provider socket, the layered socket and the
// application's hWnd/wMsg, and returns a worker window. That window is
// presumably responsible for relaying notifications to the application. The
// next provider is then asked to post WM_SOCKET to the worker window for the
// underlying provider socket.
//
// Returns the next provider's result, or SOCKET_ERROR with *lpErrno set.
int WSPAPI WSPAsyncSelect (
    SOCKET s,
    HWND hWnd,
    unsigned int wMsg,
    long lEvent,
    LPINT lpErrno)
{
    SOCK_INFO *SocketContext;
    HWND hWorkerWindow;

    if (MainUpCallTable.lpWPUQuerySocketHandleContext(s, (LPDWORD) &SocketContext, lpErrno) == SOCKET_ERROR)
        return SOCKET_ERROR;

    hWorkerWindow = SetWorkerWindow(SocketContext->ProviderSocket, s, hWnd, wMsg);
    if (hWorkerWindow == NULL)
    {
        // SetWorkerWindow cannot report an error code, but the SPI contract
        // requires *lpErrno to be valid whenever SOCKET_ERROR is returned.
        // Report the failure as resource exhaustion.
        *lpErrno = WSAENOBUFS;
        return SOCKET_ERROR;
    }

    return NextProcTable.lpWSPAsyncSelect(SocketContext->ProviderSocket, hWorkerWindow, WM_SOCKET, lEvent, lpErrno);
}
// WSPBind
//
// Binds layered socket s to a local address. It does this by binding the
// underlying provider socket stored in the socket's SOCK_INFO context.
//
// Returns the next provider's result. If s has no context, returns
// SOCKET_ERROR with *lpErrno as reported by the upcall.
int WSPAPI WSPBind(
    SOCKET s,
    const struct sockaddr FAR * name,
    int namelen,
    LPINT lpErrno)
{
    SOCK_INFO *Ctx = NULL;
    int QueryResult;

    QueryResult = MainUpCallTable.lpWPUQuerySocketHandleContext(s, (LPDWORD) &Ctx, lpErrno);
    if (QueryResult != SOCKET_ERROR)
        return NextProcTable.lpWSPBind(Ctx->ProviderSocket, name, namelen, lpErrno);

    return SOCKET_ERROR;
}
// WSPCancelBlockingCall
//
// Forwards a request to cancel the in-progress blocking call directly to the
// next provider. This layer keeps no state of its own for it.
int WSPAPI WSPCancelBlockingCall(
    LPINT lpErrno)
{
    const int Result = NextProcTable.lpWSPCancelBlockingCall(lpErrno);
    return Result;
}
// WSPCleanup
//
// Drops one startup reference (gEntryCount, presumably incremented by
// WSPStartup). The call is forwarded to the next provider. When the last
// reference is dropped, hProvider is unloaded.
//
// The "is initialized" check and the decrement happen under one hold of
// gCriticalSection. Otherwise two concurrent cleanups could both pass the
// check and wrap the unsigned DWORD counter below zero.
//
// Returns the next provider's result, or SOCKET_ERROR with
// *lpErrno = WSANOTINITIALISED when there is no outstanding startup.
int WSPAPI WSPCleanup (
    LPINT lpErrno
    )
{
    int Ret;

    EnterCriticalSection(&gCriticalSection);

    if (gEntryCount == 0)
    {
        LeaveCriticalSection(&gCriticalSection);
        *lpErrno = WSANOTINITIALISED;
        return SOCKET_ERROR;
    }

    // As before, the reference is dropped even if the next provider fails.
    Ret = NextProcTable.lpWSPCleanup(lpErrno);

    gEntryCount--;
    if (gEntryCount == 0)
    {
        // NOTE(review): NextProcTable presumably points into this module, so
        // it must not be used again until WSPStartup reloads the provider.
        FreeLibrary(hProvider);
        hProvider = NULL;
    }

    LeaveCriticalSection(&gCriticalSection);
    return Ret;
}
// WSPCloseSocket
//
// Closes layered socket s. The underlying provider socket is always closed
// first.
//
// If overlapped operations are still outstanding (dwOutstandingAsync != 0),
// the context is only marked bClosing and the call returns 0. The remaining
// teardown is presumably finished by the overlapped-I/O completion code,
// which is not visible here.
//
// Otherwise the socket info is removed, the layered handle is closed, the
// byte counters are logged with OutputDebugString, and the context is freed.
//
// Returns 0 on success or SOCKET_ERROR with *lpErrno set.
int WSPAPI WSPCloseSocket (
    SOCKET s,
    LPINT lpErrno
    )
{
    SOCK_INFO *Ctx = NULL;
    BOOL bDeferTeardown;

    if (MainUpCallTable.lpWPUQuerySocketHandleContext(s, (LPDWORD) &Ctx, lpErrno) == SOCKET_ERROR)
        return SOCKET_ERROR;

    bDeferTeardown = (Ctx->dwOutstandingAsync != 0);

    // bClosing is set before the lower close, presumably so that completion
    // code triggered by that close already sees it.
    if (bDeferTeardown)
        Ctx->bClosing = TRUE;

    if (NextProcTable.lpWSPCloseSocket(Ctx->ProviderSocket, lpErrno) == SOCKET_ERROR)
        return SOCKET_ERROR;

    if (bDeferTeardown)
        return 0;

    RemoveSockInfo(Ctx->ProviderSocket);

    // NOTE(review): if this fails the context is not freed.
    if (MainUpCallTable.lpWPUCloseSocketHandle(s, lpErrno) == SOCKET_ERROR)
        return SOCKET_ERROR;

    {
        TCHAR buffer[128];
        wsprintf(buffer, L"Closing socket %d Bytes Sent [%lu] Bytes Recv [%lu]\n", s,
                 Ctx->BytesSent, Ctx->BytesRecv);
        OutputDebugString(buffer);
    }

    GlobalFree(Ctx);
    return 0;
}
// WSPConnect
//
// Connects layered socket s by passing every argument unchanged to the next
// provider, targeting the underlying provider socket.
//
// Returns the next provider's result, or SOCKET_ERROR if s has no context.
int WSPAPI WSPConnect (
    SOCKET s,
    const struct sockaddr FAR * name,
    int namelen,
    LPWSABUF lpCallerData,
    LPWSABUF lpCalleeData,
    LPQOS lpSQOS,
    LPQOS lpGQOS,
    LPINT lpErrno
    )
{
    SOCK_INFO *SockCtx = NULL;

    if (MainUpCallTable.lpWPUQuerySocketHandleContext(s, (LPDWORD) &SockCtx, lpErrno) != SOCKET_ERROR)
    {
        return NextProcTable.lpWSPConnect(SockCtx->ProviderSocket,
                                          name, namelen,
                                          lpCallerData, lpCalleeData,
                                          lpSQOS, lpGQOS,
                                          lpErrno);
    }
    return SOCKET_ERROR;
}
// WSPDuplicateSocket
//
// Fills lpProtocolInfo so that socket s can be shared with process
// dwProcessId. It does this by asking the next provider to duplicate the
// underlying provider socket.
//
// NOTE(review): because the lower provider fills in lpProtocolInfo, it
// presumably describes the base provider rather than this layered entry. A
// socket recreated from it in the target process may therefore bypass this
// LSP -- confirm this is intended.
int WSPAPI WSPDuplicateSocket(
    SOCKET s,
    DWORD dwProcessId,
    LPWSAPROTOCOL_INFOW lpProtocolInfo,
    LPINT lpErrno)
{
    SOCK_INFO *Ctx;

    if (MainUpCallTable.lpWPUQuerySocketHandleContext(s, (LPDWORD) &Ctx, lpErrno) == SOCKET_ERROR)
    {
        return SOCKET_ERROR;
    }

    SOCKET LowerSocket = Ctx->ProviderSocket;
    return NextProcTable.lpWSPDuplicateSocket(LowerSocket, dwProcessId, lpProtocolInfo, lpErrno);
}
// WSPEnumNetworkEvents
//
// Reports the network events recorded for layered socket s. The query is
// made against the underlying provider socket through the next provider.
//
// Returns the next provider's result, or SOCKET_ERROR if s has no context.
int WSPAPI WSPEnumNetworkEvents(
    SOCKET s,
    WSAEVENT hEventObject,
    LPWSANETWORKEVENTS lpNetworkEvents,
    LPINT lpErrno)
{
    SOCK_INFO *SockCtx = NULL;
    const BOOL bHaveContext =
        MainUpCallTable.lpWPUQuerySocketHandleContext(s, (LPDWORD) &SockCtx, lpErrno) != SOCKET_ERROR;

    return bHaveContext
        ? NextProcTable.lpWSPEnumNetworkEvents(SockCtx->ProviderSocket, hEventObject,
                                               lpNetworkEvents, lpErrno)
        : SOCKET_ERROR;
}
// WSPEventSelect
//
// Associates hEventObject with the lNetworkEvents of interest on layered
// socket s. The registration is made on the underlying provider socket
// through the next provider.
//
// Returns the next provider's result, or SOCKET_ERROR if s has no context.
int WSPAPI WSPEventSelect(
    SOCKET s,
    WSAEVENT hEventObject,
    long lNetworkEvents,
    LPINT lpErrno)
{
    SOCK_INFO *Context = NULL;
    int Result;

    if (MainUpCallTable.lpWPUQuerySocketHandleContext(s, (LPDWORD) &Context, lpErrno) == SOCKET_ERROR)
    {
        Result = SOCKET_ERROR;
    }
    else
    {
        Result = NextProcTable.lpWSPEventSelect(Context->ProviderSocket, hEventObject,
                                                lNetworkEvents, lpErrno);
    }
    return Result;
}
// WSPGetOverlappedResult
//
// Reports the result of an overlapped operation tracked by this layer. The
// operation counts as pending while lpOverlapped->Internal equals
// WSS_OPERATION_IN_PROGRESS. Once complete, this function reads the results
// from the WSAOVERLAPPED fields:
//   InternalHigh -> *lpcbTransfer (bytes transferred)
//   Offset       -> *lpdwFlags
//   OffsetHigh   -> *lpErrno (0 means success)
//
// If the operation is still pending:
//   - fWait == FALSE: fails with WSA_IO_INCOMPLETE.
//   - fWait != FALSE: waits on lpOverlapped->hEvent. If the wait fails, or the
//     operation is still pending afterwards, fails with WSASYSCALLFAILURE.
//
// s is not used.
//
// Returns TRUE when the completed operation's error code is 0, else FALSE.
BOOL WSPAPI WSPGetOverlappedResult (
    SOCKET s,
    LPWSAOVERLAPPED lpOverlapped,
    LPDWORD lpcbTransfer,
    BOOL fWait,
    LPDWORD lpdwFlags,
    LPINT lpErrno)
{
    if (lpOverlapped->Internal == WSS_OPERATION_IN_PROGRESS)
    {
        DWORD WaitStatus;

        if (!fWait)
        {
            *lpErrno = WSA_IO_INCOMPLETE;
            return FALSE;
        }

        WaitStatus = WaitForSingleObject(lpOverlapped->hEvent, INFINITE);
        if (WaitStatus != WAIT_OBJECT_0
            || lpOverlapped->Internal == WSS_OPERATION_IN_PROGRESS)
        {
            *lpErrno = WSASYSCALLFAILURE;
            return FALSE;
        }
    }

    // Operation has completed: copy out the stored results.
    *lpcbTransfer = lpOverlapped->InternalHigh;
    *lpdwFlags = lpOverlapped->Offset;
    *lpErrno = lpOverlapped->OffsetHigh;
    return (lpOverlapped->OffsetHigh == 0) ? TRUE : FALSE;
}
// WSPGetPeerName
//
// Retrieves the address of the peer connected to layered socket s. The
// lookup is done on the underlying provider socket through the next provider.
//
// Returns the next provider's result, or SOCKET_ERROR if s has no context.
int WSPAPI WSPGetPeerName(
    SOCKET s,
    struct sockaddr FAR * name,
    LPINT namelen,
    LPINT lpErrno)
{
    SOCK_INFO *SockCtx = NULL;

    if (MainUpCallTable.lpWPUQuerySocketHandleContext(s, (LPDWORD) &SockCtx, lpErrno) == SOCKET_ERROR)
    {
        return SOCKET_ERROR;
    }

    SOCKET LowerSocket = SockCtx->ProviderSocket;
    return NextProcTable.lpWSPGetPeerName(LowerSocket, name, namelen, lpErrno);
}
// WSPGetSockName
//
// Retrieves the local address bound to layered socket s. The lookup is done
// on the underlying provider socket through the next provider.
//
// Returns the next provider's result, or SOCKET_ERROR if s has no context.
int WSPAPI WSPGetSockName(
    SOCKET s,
    struct sockaddr FAR * name,
    LPINT namelen,
    LPINT lpErrno)
{
    SOCK_INFO *Context = NULL;
    const int QueryStatus =
        MainUpCallTable.lpWPUQuerySocketHandleContext(s, (LPDWORD) &Context, lpErrno);

    if (QueryStatus != SOCKET_ERROR)
    {
        return NextProcTable.lpWSPGetSockName(Context->ProviderSocket,
                                              name, namelen, lpErrno);
    }
    return SOCKET_ERROR;
}
int WSPAPI WSPGetSockOpt(
SOCKET s,
int level,
int optname,
char FAR * optval,
LPINT optlen,
LPINT lpErrno)
{
SOCK_INFO *SocketContext;
if (MainUpCallTable.lpWPUQuerySocketHandleContext(s, (LPDWORD) &SocketContext, lpErrno) == SOCKET_ERROR)
return SOCKET_ERROR;
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -