// spi.cpp
// THIS CODE AND INFORMATION IS PROVIDED "AS IS" WITHOUT WARRANTY OF
// ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING BUT NOT LIMITED TO
// THE IMPLIED WARRANTIES OF MERCHANTABILITY AND/OR FITNESS FOR A
// PARTICULAR PURPOSE.
//
// Copyright (C) 1999 Microsoft Corporation. All Rights Reserved.
//
// Module Name: spi.cpp
//
// Description:
//
// This sample illustrates how to develop a layered service provider that is
// capable of counting all bytes transmitted through an IP socket. The application
// reports when sockets are created and reports how many bytes were sent and
// received when a socket closes. The results are reported using the OutputDebugString
// API which will allow you to intercept the I/O by using a debugger such as cdb.exe
// or you can monitor the I/O using dbmon.exe.
//
// This file contains the 30 SPI functions you are required to implement in a
// service provider. It also contains the two functions that must be exported
// from the DLL module: DllMain and WSPStartup.
//
#include "provider.h"
#include "install.h"
#include <stdio.h>
#include <stdlib.h>
//
// Globals used across files
//
CRITICAL_SECTION gCriticalSection,  // held around FreeSocketsAndMemory during DLL detach
                 gOverlappedCS,     // NOTE(review): users are outside this chunk
                 gDebugCritSec;     // serializes dbgprint output
WSPUPCALLTABLE MainUpCallTable;     // Winsock upcalls (e.g. lpWPUCreateSocketHandle)
HINSTANCE hDllInstance = NULL;      // this DLL's module handle, set in DllMain
LPPROVIDER gBaseInfo = NULL;        // gLayerCount entries, one per lower provider
INT gLayerCount=0; // Number of base providers we're layered over
HANDLE hLspHeap=NULL;               // presumably the LSP's private heap -- confirm
extern HANDLE ghIocp; // Handle to IO completion port
void FreeSocketsAndMemory(int *lpErrno);
//
// Need to keep track of which PROVIDER is currently executing a blocking
// Winsock call on a per-thread basis. Stores Provider in this thread's TLS
// slot; evaluates to FALSE and does nothing when no TLS slot is allocated
// (TlsIndex == 0xFFFFFFFF, the value of TLS_OUT_OF_INDEXES).
// The argument and the whole expansion are parenthesized so the macro is
// safe to use inside larger expressions.
//
#define SetBlockingProvider(Provider) \
    ((TlsIndex != 0xFFFFFFFF) \
        ? TlsSetValue(TlsIndex, (Provider)) \
        : FALSE)
//
// Globals local to this file
//
static DWORD TlsIndex=0xFFFFFFFF;  // TLS slot from TlsAlloc; 0xFFFFFFFF means none
static DWORD gEntryCount = 0; // how many times WSPStartup has been called
static DWORD gLayerCatId = 0; // Catalog ID of our dummy entry
static WSPDATA gWSPData;
static WSPPROC_TABLE gProcTable;
static BOOL bDetached=FALSE;       // set to TRUE on DLL_PROCESS_DETACH
static TCHAR Msg[512]; // For outputting debug messages
//
// Function: dbgprint
//
// Description:
//    printf-style debug trace. Formats the message, prefixes it with the
//    current process id, appends CR/LF and emits the line with
//    OutputDebugString. Output is unconditional (it is not compiled out in
//    release builds). gDebugCritSec serializes the OutputDebugString call so
//    lines from different threads are not interleaved; formatting happens in
//    stack buffers and needs no lock.
//
//    Formatting uses the bounded vsnprintf/snprintf, so an over-long message
//    is truncated instead of overflowing the stack buffers (wvsprintf and
//    wsprintf take no buffer size).
//
//    NOTE(review): format is a narrow string and is passed to
//    OutputDebugString, so this assumes a non-UNICODE build.
//
void dbgprint(char *format,...)
{
    static DWORD pid = 0;
    va_list      vl;
    char         dbgbuf1[2048];
    char         dbgbuf2[2048 + 32];   // room for "<pid>: " prefix and "\r\n"

    if (pid == 0)
    {
        pid = GetCurrentProcessId();
    }

    va_start(vl, format);
    vsnprintf(dbgbuf1, sizeof(dbgbuf1), format, vl);
    va_end(vl);
    // Defensive: older CRTs do not always terminate on truncation.
    dbgbuf1[sizeof(dbgbuf1) - 1] = '\0';

    snprintf(dbgbuf2, sizeof(dbgbuf2), "%lu: %s\r\n", (unsigned long) pid, dbgbuf1);
    dbgbuf2[sizeof(dbgbuf2) - 1] = '\0';

    EnterCriticalSection(&gDebugCritSec);
    OutputDebugString(dbgbuf2);
    LeaveCriticalSection(&gDebugCritSec);
}
//
// Function: PrintProcTable
//
// Description:
//    Dumps every entry of a WSPPROC_TABLE through dbgprint. Compiles to an
//    empty function unless DBG_PRINTPROCTABLE is defined.
//    %p is used so the full pointer width is printed; %X would truncate the
//    addresses on 64-bit builds.
//
void PrintProcTable(LPWSPPROC_TABLE lpProcTable)
{
#ifdef DBG_PRINTPROCTABLE
    dbgprint("WSPAccept = 0x%p", lpProcTable->lpWSPAccept);
    dbgprint("WSPAddressToString = 0x%p", lpProcTable->lpWSPAddressToString);
    dbgprint("WSPAsyncSelect = 0x%p", lpProcTable->lpWSPAsyncSelect);
    dbgprint("WSPBind = 0x%p", lpProcTable->lpWSPBind);
    dbgprint("WSPCancelBlockingCall = 0x%p", lpProcTable->lpWSPCancelBlockingCall);
    dbgprint("WSPCleanup = 0x%p", lpProcTable->lpWSPCleanup);
    dbgprint("WSPCloseSocket = 0x%p", lpProcTable->lpWSPCloseSocket);
    dbgprint("WSPConnect = 0x%p", lpProcTable->lpWSPConnect);
    dbgprint("WSPDuplicateSocket = 0x%p", lpProcTable->lpWSPDuplicateSocket);
    dbgprint("WSPEnumNetworkEvents = 0x%p", lpProcTable->lpWSPEnumNetworkEvents);
    dbgprint("WSPEventSelect = 0x%p", lpProcTable->lpWSPEventSelect);
    dbgprint("WSPGetOverlappedResult = 0x%p", lpProcTable->lpWSPGetOverlappedResult);
    dbgprint("WSPGetPeerName = 0x%p", lpProcTable->lpWSPGetPeerName);
    dbgprint("WSPGetSockOpt = 0x%p", lpProcTable->lpWSPGetSockOpt);
    dbgprint("WSPGetSockName = 0x%p", lpProcTable->lpWSPGetSockName);
    dbgprint("WSPGetQOSByName = 0x%p", lpProcTable->lpWSPGetQOSByName);
    dbgprint("WSPIoctl = 0x%p", lpProcTable->lpWSPIoctl);
    dbgprint("WSPJoinLeaf = 0x%p", lpProcTable->lpWSPJoinLeaf);
    dbgprint("WSPListen = 0x%p", lpProcTable->lpWSPListen);
    dbgprint("WSPRecv = 0x%p", lpProcTable->lpWSPRecv);
    dbgprint("WSPRecvDisconnect = 0x%p", lpProcTable->lpWSPRecvDisconnect);
    dbgprint("WSPRecvFrom = 0x%p", lpProcTable->lpWSPRecvFrom);
    dbgprint("WSPSelect = 0x%p", lpProcTable->lpWSPSelect);
    dbgprint("WSPSend = 0x%p", lpProcTable->lpWSPSend);
    dbgprint("WSPSendDisconnect = 0x%p", lpProcTable->lpWSPSendDisconnect);
    dbgprint("WSPSendTo = 0x%p", lpProcTable->lpWSPSendTo);
    dbgprint("WSPSetSockOpt = 0x%p", lpProcTable->lpWSPSetSockOpt);
    dbgprint("WSPShutdown = 0x%p", lpProcTable->lpWSPShutdown);
    dbgprint("WSPSocket = 0x%p", lpProcTable->lpWSPSocket);
    dbgprint("WSPStringToAddress = 0x%p", lpProcTable->lpWSPStringToAddress);
#endif
}
//
// Function: VerifyProcTable
//
// Description:
//    Checks that all 30 SPI entry points in a provider's procedure table are
//    filled in. Entries are tested in table order and the scan stops at the
//    first missing one.
//
// Returns:
//    SOCKET_ERROR if any entry is NULL, otherwise NO_ERROR.
//
int VerifyProcTable(LPWSPPROC_TABLE lpProcTable)
{
    if (lpProcTable->lpWSPAccept             == NULL ||
        lpProcTable->lpWSPAddressToString    == NULL ||
        lpProcTable->lpWSPAsyncSelect        == NULL ||
        lpProcTable->lpWSPBind               == NULL ||
        lpProcTable->lpWSPCancelBlockingCall == NULL ||
        lpProcTable->lpWSPCleanup            == NULL ||
        lpProcTable->lpWSPCloseSocket        == NULL ||
        lpProcTable->lpWSPConnect            == NULL ||
        lpProcTable->lpWSPDuplicateSocket    == NULL ||
        lpProcTable->lpWSPEnumNetworkEvents  == NULL ||
        lpProcTable->lpWSPEventSelect        == NULL ||
        lpProcTable->lpWSPGetOverlappedResult == NULL ||
        lpProcTable->lpWSPGetPeerName        == NULL ||
        lpProcTable->lpWSPGetSockOpt         == NULL ||
        lpProcTable->lpWSPGetSockName        == NULL ||
        lpProcTable->lpWSPGetQOSByName       == NULL ||
        lpProcTable->lpWSPIoctl              == NULL ||
        lpProcTable->lpWSPJoinLeaf           == NULL ||
        lpProcTable->lpWSPListen             == NULL ||
        lpProcTable->lpWSPRecv               == NULL ||
        lpProcTable->lpWSPRecvDisconnect     == NULL ||
        lpProcTable->lpWSPRecvFrom           == NULL ||
        lpProcTable->lpWSPSelect             == NULL ||
        lpProcTable->lpWSPSend               == NULL ||
        lpProcTable->lpWSPSendDisconnect     == NULL ||
        lpProcTable->lpWSPSendTo             == NULL ||
        lpProcTable->lpWSPSetSockOpt         == NULL ||
        lpProcTable->lpWSPShutdown           == NULL ||
        lpProcTable->lpWSPSocket             == NULL ||
        lpProcTable->lpWSPStringToAddress    == NULL)
    {
        // At least one SPI entry point is missing.
        return SOCKET_ERROR;
    }

    return NO_ERROR;
}
//
// Function: DllMain
//
// Description:
//    DLL entry point.
//    DLL_PROCESS_ATTACH: records the module handle in hDllInstance,
//    initializes the three global critical sections and allocates the TLS
//    slot in which SetBlockingProvider stores the per-thread PROVIDER. If
//    TlsAlloc fails, TlsIndex holds 0xFFFFFFFF (TLS_OUT_OF_INDEXES) and
//    SetBlockingProvider does nothing.
//    DLL_PROCESS_DETACH: sets bDetached, calls FreeSocketsAndMemory under
//    gCriticalSection (only if gBaseInfo was set up), deletes the critical
//    sections, and frees the TLS slot only when lpvReserved is NULL.
//    Thread attach/detach notifications are ignored.
//
// Returns:
//    Always TRUE; setup failures are not reported to the loader.
//
//    NOTE(review): DllMain runs under the loader lock, and when lpvReserved
//    is non-NULL the process is terminating and other threads may already
//    be gone. FreeSocketsAndMemory (body not visible here) still runs in
//    that case, and the critical sections are deleted regardless -- confirm
//    nothing (e.g. dbgprint) can use them afterwards.
//
BOOL WINAPI DllMain(IN HINSTANCE hinstDll, IN DWORD dwReason, LPVOID lpvReserved)
{
switch (dwReason)
{
case DLL_PROCESS_ATTACH:
hDllInstance = hinstDll;
//
// Initialize the critical sections used throughout the DLL
//
InitializeCriticalSection(&gCriticalSection);
InitializeCriticalSection(&gOverlappedCS);
InitializeCriticalSection(&gDebugCritSec);
// Per-thread slot used by SetBlockingProvider; 0xFFFFFFFF on failure
TlsIndex = TlsAlloc();
break;
case DLL_THREAD_ATTACH:
break;
case DLL_THREAD_DETACH:
break;
case DLL_PROCESS_DETACH:
bDetached = TRUE;
// Release layered state under the global lock before that lock is
// deleted below.
EnterCriticalSection(&gCriticalSection);
if (gBaseInfo)
{
int Error;
FreeSocketsAndMemory(&Error);
}
LeaveCriticalSection(&gCriticalSection);
DeleteCriticalSection(&gCriticalSection);
DeleteCriticalSection(&gOverlappedCS);
DeleteCriticalSection(&gDebugCritSec);
// lpvReserved == NULL indicates a dynamic unload (FreeLibrary) rather
// than process termination; only then is the TLS slot released.
if (lpvReserved == NULL)
{
if (TlsIndex != 0xFFFFFFFF)
{
TlsFree(TlsIndex);
TlsIndex = 0xFFFFFFFF;
}
}
break;
}
return TRUE;
}
//
// Function: WSPAccept
//
// Description:
//    Handle the WSAAccept function. The accept is forwarded to the next
//    provider using the listening socket's underlying provider socket. On
//    success, a SOCK_INFO context is created for the new provider socket and
//    a layered socket handle is created for it with WPUCreateSocketHandle;
//    that layered handle is what the application receives.
//
//    The conditional accept callback (lpfnCondition) is passed straight
//    through. You can choose to intercept it by substituting your own
//    callback (you'll need to keep track of the user-supplied callback so
//    you can invoke it once your substituted function is triggered).
//
// Returns:
//    The new layered socket handle, or INVALID_SOCKET with *lpErrno set.
//
SOCKET WSPAPI WSPAccept (
    SOCKET s,
    struct sockaddr FAR * addr,
    LPINT addrlen,
    LPCONDITIONPROC lpfnCondition,
    DWORD_PTR dwCallbackData,
    LPINT lpErrno)
{
    SOCKET     NewProviderSocket;
    SOCKET     NewSocket = INVALID_SOCKET;
    SOCK_INFO *NewSocketContext;
    SOCK_INFO *SocketContext;
    INT        CloseErr;

    // Query for our per socket info
    SocketContext = FindAndLockSocketContext(s, lpErrno);
    if (SocketContext == NULL)
    {
        *lpErrno = WSAENOTSOCK;
        return INVALID_SOCKET;
    }

    //
    // Note: You can substitute your own conditional accept callback function
    // in order to intercept this callback. You would have to keep track
    // of the user's callback function so that you can call that when
    // your intermediate function executes.
    //
    SetBlockingProvider(SocketContext->Provider);
    NewProviderSocket = SocketContext->Provider->NextProcTable.lpWSPAccept(
        SocketContext->ProviderSocket,
        addr,
        addrlen,
        lpfnCondition,
        dwCallbackData,
        lpErrno);
    SetBlockingProvider(NULL);

    if (NewProviderSocket != INVALID_SOCKET)
    {
        // The underlying provider received a new connection so lets create our own
        // socket to pass back up to the application.
        //
        NewSocketContext = CreateSockInfo(SocketContext->Provider,
                                          NewProviderSocket,
                                          SocketContext);
        if (NewSocketContext == NULL)
        {
            // No context references the accepted provider socket, so close it
            // instead of leaking it. CloseErr keeps the close result from
            // overwriting the WSAENOBUFS reported to the caller.
            SetBlockingProvider(SocketContext->Provider);
            SocketContext->Provider->NextProcTable.lpWSPCloseSocket(
                NewProviderSocket,
                &CloseErr);
            SetBlockingProvider(NULL);
            *lpErrno = WSAENOBUFS;
        }
        else
        {
            NewSocket = MainUpCallTable.lpWPUCreateSocketHandle(
                SocketContext->Provider->LayeredProvider.ProtocolChain.ChainEntries[0],
                (DWORD_PTR) NewSocketContext,
                lpErrno);
            NewSocketContext->LayeredSocket = NewSocket;
            if (NewSocket == INVALID_SOCKET)
            {
                dbgprint("WSPAccept(): WPUCreateSocketHandle() failed: %d", *lpErrno);
                // NOTE(review): NewSocketContext and NewProviderSocket are not
                // released on this path; the helper that destroys a SOCK_INFO
                // is not visible in this chunk. Confirm and free them here.
            }
            else
            {
                dbgprint("Creating socket %d", NewSocket);
            }
        }
    }

    UnlockSocketContext(SocketContext, lpErrno);
    return NewSocket;
}
//
// Function: WSPAddressToString
//
// Description:
//    Convert an address to a string. We find the lower provider whose
//    protocol info matches lpProtocolInfo (address family, socket type,
//    protocol, and dwServiceFlags1 ignoring XP1_IFS_HANDLES) and pass the
//    call to it.
//
// Returns:
//    The lower provider's result, or SOCKET_ERROR with *lpErrno set to
//    WSAEINVAL when lpProtocolInfo is NULL or no matching provider exists.
//
int WSPAPI WSPAddressToString(
    LPSOCKADDR lpsaAddress,
    DWORD dwAddressLength,
    LPWSAPROTOCOL_INFOW lpProtocolInfo,
    LPWSTR lpszAddressString,
    LPDWORD lpdwAddressStringLength,
    LPINT lpErrno)
{
    WSAPROTOCOL_INFOW *pInfo = NULL;
    PROVIDER          *Provider = NULL;
    INT                i, ret;

    // lpProtocolInfo is required to select the lower provider; reject NULL
    // rather than dereferencing it in the matching loop below.
    if (lpProtocolInfo == NULL)
    {
        *lpErrno = WSAEINVAL;
        return SOCKET_ERROR;
    }

    // First find the appropriate provider
    //
    for (i = 0; i < gLayerCount; i++)
    {
        WSAPROTOCOL_INFOW *Next = &gBaseInfo[i].NextProvider;

        if ((Next->iAddressFamily != lpProtocolInfo->iAddressFamily) ||
            (Next->iSocketType    != lpProtocolInfo->iSocketType) ||
            (Next->iProtocol      != lpProtocolInfo->iProtocol))
        {
            continue;
        }

        // In case of multiple providers check the provider flags
        if ((Next->dwServiceFlags1 & ~XP1_IFS_HANDLES) !=
            (lpProtocolInfo->dwServiceFlags1 & ~XP1_IFS_HANDLES))
        {
            continue;
        }

        Provider = &gBaseInfo[i];
        pInfo = Next;
        break;
    }

    if (Provider == NULL)
    {
        *lpErrno = WSAEINVAL;
        return SOCKET_ERROR;
    }

    // Of course if the next layer isn't a base just pass down lpProtocolInfo.
    //
    if (pInfo->ProtocolChain.ChainLen != BASE_PROTOCOL)
    {
        pInfo = lpProtocolInfo;
    }

    SetBlockingProvider(Provider);
    ret = Provider->NextProcTable.lpWSPAddressToString(lpsaAddress,
                                                       dwAddressLength,
                                                       pInfo,
                                                       lpszAddressString,
                                                       lpdwAddressStringLength,
                                                       lpErrno);
    SetBlockingProvider(NULL);
    return ret;
}
//
// Function: WSPAsyncSelect
//
// Description:
// Register specific Winsock events with a socket. We need to substitute
// the app socket with the provider socket and use our own hidden window.
//
int WSPAPI WSPAsyncSelect (
SOCKET s,
HWND hWnd,
unsigned int wMsg,
long lEvent,
LPINT lpErrno)
{
SOCK_INFO *SocketContext;
HWND hWorkerWindow;
INT ret;
// Make sure the window handle is valid
//
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -