📄 main.cpp
字号:
// WSPSend: this layer's hook for the provider-level send call.
//
// Resolves this layer's SOCK_INFO context for socket s, then forwards the send
// to the next provider on SocketContext->ProviderSocket.
//   * Overlapped (lpOverlapped != NULL): the request is re-wrapped by
//     GetOverlappedStructure (defined elsewhere). That call is given the
//     caller's overlapped, completion routine, thread ID and
//     &SocketContext->BytesSent -- presumably so the byte count can be
//     accounted when the I/O completes; TODO confirm. NULL is passed for the
//     completion routine / thread ID to the next provider.
//   * Non-overlapped: on success *lpNumberOfBytesSent is added to
//     SocketContext->BytesSent.
//
// Returns the next provider's result, or SOCKET_ERROR with *lpErrno set.
int WSPAPI WSPSend (
    SOCKET s,
    LPWSABUF lpBuffers,
    DWORD dwBufferCount,
    LPDWORD lpNumberOfBytesSent,
    DWORD dwFlags,
    LPWSAOVERLAPPED lpOverlapped,
    LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine,
    LPWSATHREADID lpThreadId,
    LPINT lpErrno
    )
{
    INT Ret;
    SOCK_INFO *SocketContext;
    LPWSAOVERLAPPED ProviderOverlapped;

    // The context stored with the handle is a pointer, so it must be fetched
    // through a pointer-sized (DWORD_PTR) slot; an LPDWORD cast breaks x64.
    if (MainUpCallTable.lpWPUQuerySocketHandleContext(s, (PDWORD_PTR) &SocketContext, lpErrno) == SOCKET_ERROR)
        return SOCKET_ERROR;

    if (lpOverlapped)
    {
        ProviderOverlapped = GetOverlappedStructure(s, SocketContext->ProviderSocket, lpOverlapped,
            lpCompletionRoutine, lpThreadId, &SocketContext->BytesSent);
        if (ProviderOverlapped == NULL)
        {
            // Forwarding a NULL overlapped would silently turn the caller's
            // overlapped request into a blocking send that never signals it.
            // NOTE(review): assumes NULL means GetOverlappedStructure could not
            // obtain a structure -- confirm against its definition.
            *lpErrno = WSAENOBUFS;
            return SOCKET_ERROR;
        }
        return NextProcTable.lpWSPSend(SocketContext->ProviderSocket, lpBuffers, dwBufferCount,
            lpNumberOfBytesSent, dwFlags, ProviderOverlapped, NULL, NULL, lpErrno);
    }

    // Non-overlapped: lpOverlapped is NULL here, so pass NULL through.
    Ret = NextProcTable.lpWSPSend(SocketContext->ProviderSocket, lpBuffers, dwBufferCount,
        lpNumberOfBytesSent, dwFlags, NULL, lpCompletionRoutine, lpThreadId, lpErrno);
    if (Ret != SOCKET_ERROR && lpNumberOfBytesSent != NULL)
        SocketContext->BytesSent += *lpNumberOfBytesSent;
    return Ret;
}
// WSPSendDisconnect: this layer's hook for the provider-level send-disconnect
// call. Resolves this layer's SOCK_INFO context for socket s and forwards the
// request unchanged to the next provider on SocketContext->ProviderSocket.
// Returns the next provider's result, or SOCKET_ERROR with *lpErrno set.
int WSPAPI WSPSendDisconnect(
    SOCKET s,
    LPWSABUF lpOutboundDisconnectData,
    LPINT lpErrno)
{
    SOCK_INFO *SocketContext;

    // Pointer-sized slot for the stored context (LPDWORD is wrong on x64).
    if (MainUpCallTable.lpWPUQuerySocketHandleContext(s, (PDWORD_PTR) &SocketContext, lpErrno) == SOCKET_ERROR)
        return SOCKET_ERROR;

    return NextProcTable.lpWSPSendDisconnect(SocketContext->ProviderSocket,
        lpOutboundDisconnectData, lpErrno);
}
// WSPSendTo: this layer's hook for the provider-level send-to call.
//
// Resolves this layer's SOCK_INFO context for socket s, then forwards the
// datagram send (lpTo / iTolen unchanged) to the next provider on
// SocketContext->ProviderSocket.
//   * Overlapped (lpOverlapped != NULL): the request is re-wrapped by
//     GetOverlappedStructure (defined elsewhere), which is also given
//     &SocketContext->BytesSent -- presumably for accounting on completion;
//     TODO confirm. NULL is passed for completion routine / thread ID.
//   * Non-overlapped: on success *lpNumberOfBytesSent is added to
//     SocketContext->BytesSent.
//
// Returns the next provider's result, or SOCKET_ERROR with *lpErrno set.
int WSPAPI WSPSendTo(
    SOCKET s,
    LPWSABUF lpBuffers,
    DWORD dwBufferCount,
    LPDWORD lpNumberOfBytesSent,
    DWORD dwFlags,
    const struct sockaddr FAR * lpTo,
    int iTolen,
    LPWSAOVERLAPPED lpOverlapped,
    LPWSAOVERLAPPED_COMPLETION_ROUTINE lpCompletionRoutine,
    LPWSATHREADID lpThreadId,
    LPINT lpErrno)
{
    SOCK_INFO *SocketContext;
    int Ret;
    LPWSAOVERLAPPED ProviderOverlapped;

    // Pointer-sized slot for the stored context (LPDWORD is wrong on x64).
    if (MainUpCallTable.lpWPUQuerySocketHandleContext(s, (PDWORD_PTR) &SocketContext, lpErrno) == SOCKET_ERROR)
        return SOCKET_ERROR;

    if (lpOverlapped)
    {
        ProviderOverlapped = GetOverlappedStructure(s, SocketContext->ProviderSocket, lpOverlapped,
            lpCompletionRoutine, lpThreadId, &SocketContext->BytesSent);
        if (ProviderOverlapped == NULL)
        {
            // A NULL overlapped would make the next provider perform a blocking
            // send and the caller's overlapped would never be signalled.
            // NOTE(review): assumes NULL means GetOverlappedStructure failed --
            // confirm against its definition.
            *lpErrno = WSAENOBUFS;
            return SOCKET_ERROR;
        }
        return NextProcTable.lpWSPSendTo(SocketContext->ProviderSocket, lpBuffers, dwBufferCount,
            lpNumberOfBytesSent, dwFlags, lpTo, iTolen, ProviderOverlapped, NULL, NULL, lpErrno);
    }

    // Non-overlapped: lpOverlapped is NULL here, so pass NULL through.
    Ret = NextProcTable.lpWSPSendTo(SocketContext->ProviderSocket, lpBuffers, dwBufferCount,
        lpNumberOfBytesSent, dwFlags, lpTo, iTolen, NULL, lpCompletionRoutine, lpThreadId, lpErrno);
    if (Ret != SOCKET_ERROR && lpNumberOfBytesSent != NULL)
        SocketContext->BytesSent += *lpNumberOfBytesSent;
    return Ret;
}
// WSPSetSockOpt: this layer's hook for the provider-level setsockopt call.
// Resolves this layer's SOCK_INFO context for socket s and forwards level,
// optname, optval and optlen unchanged to the next provider on
// SocketContext->ProviderSocket.
// Returns the next provider's result, or SOCKET_ERROR with *lpErrno set.
int WSPAPI WSPSetSockOpt(
    SOCKET s,
    int level,
    int optname,
    const char FAR * optval,
    int optlen,
    LPINT lpErrno)
{
    SOCK_INFO *SocketContext;

    // Pointer-sized slot for the stored context (LPDWORD is wrong on x64).
    if (MainUpCallTable.lpWPUQuerySocketHandleContext(s, (PDWORD_PTR) &SocketContext, lpErrno) == SOCKET_ERROR)
        return SOCKET_ERROR;

    return NextProcTable.lpWSPSetSockOpt(SocketContext->ProviderSocket, level,
        optname, optval, optlen, lpErrno);
}
// WSPShutdown: this layer's hook for the provider-level shutdown call.
// Resolves this layer's SOCK_INFO context for socket s and forwards `how`
// unchanged to the next provider on SocketContext->ProviderSocket.
// Returns the next provider's result, or SOCKET_ERROR with *lpErrno set.
int WSPAPI WSPShutdown (
    SOCKET s,
    int how,
    LPINT lpErrno)
{
    SOCK_INFO *SocketContext;

    // Pointer-sized slot for the stored context (LPDWORD is wrong on x64).
    if (MainUpCallTable.lpWPUQuerySocketHandleContext(s, (PDWORD_PTR) &SocketContext, lpErrno) == SOCKET_ERROR)
        return SOCKET_ERROR;

    return NextProcTable.lpWSPShutdown(SocketContext->ProviderSocket, how, lpErrno);
}
// WSPStringToAddress: forwards address-string parsing to the next provider.
//
// The next provider is handed the first recorded base-provider entry
// (gBaseInfo[0]) instead of the caller's lpProtocolInfo, which presumably
// describes this layered chain rather than a protocol the next provider owns
// -- TODO confirm. When no base entry was recorded (gBaseInfo NULL or
// gLayerCount == 0), the caller's lpProtocolInfo is passed through. This avoids
// reading past an empty allocation.
// Returns the next provider's result.
int WSPAPI WSPStringToAddress(
    LPWSTR AddressString,
    INT AddressFamily,
    LPWSAPROTOCOL_INFOW lpProtocolInfo,
    LPSOCKADDR lpAddress,
    LPINT lpAddressLength,
    LPINT lpErrno)
{
    LPWSAPROTOCOL_INFOW NextInfo =
        (gBaseInfo != NULL && gLayerCount > 0) ? &gBaseInfo[0] : lpProtocolInfo;

    return NextProcTable.lpWSPStringToAddress(AddressString, AddressFamily,
        NextInfo, lpAddress, lpAddressLength, lpErrno);
}
// WSPSocket: creates a socket through the next provider and wraps it in a
// socket handle owned by this layer.
//
// Steps:
//   1. Pick the socket type/protocol to match on. A supplied lpProtocolInfo
//      takes precedence over the separate type/protocol arguments.
//   2. Look for a gBaseInfo entry with that socket type and protocol. If one is
//      found it is passed to the next provider, otherwise lpProtocolInfo is.
//   3. Allocate a zeroed SOCK_INFO context (GlobalAlloc/GPTR) that remembers
//      the next provider's socket and byte counters.
//   4. Register the context with Ws2_32 via WPUCreateSocketHandle under
//      gChainId.
// On any failure after the next provider's socket was created, that socket
// (and the context, if allocated) is released so nothing leaks.
//
// Returns the new layered socket handle, or INVALID_SOCKET with *lpErrno set.
SOCKET WSPAPI WSPSocket(
    int af,
    int type,
    int protocol,
    LPWSAPROTOCOL_INFOW lpProtocolInfo,
    GROUP g,
    DWORD dwFlags,
    LPINT lpErrno
    )
{
    SOCKET NextProviderSocket;
    SOCKET NewSocket;
    SOCK_INFO *SocketContext;
    LPWSAPROTOCOL_INFOW pInfo = NULL;
    INT iProtocol, iSockType, i;
    INT CloseErr;

    // Use the caller's protocol info when present (the original test was
    // inverted and dereferenced a NULL pointer).
    iProtocol = (lpProtocolInfo ? lpProtocolInfo->iProtocol : protocol);
    iSockType = (lpProtocolInfo ? lpProtocolInfo->iSocketType : type);

    for (i = 0; i < gLayerCount; i++)
    {
        if ((gBaseInfo[i].iSocketType == iSockType) &&
            (gBaseInfo[i].iProtocol == iProtocol))
        {
            pInfo = &gBaseInfo[i];
            break;
        }
    }

    NextProviderSocket = NextProcTable.lpWSPSocket(af, type, protocol, (pInfo ? pInfo : lpProtocolInfo),
        g, dwFlags, lpErrno);
    if (NextProviderSocket == INVALID_SOCKET)
        return INVALID_SOCKET;

    if ((SocketContext = (SOCK_INFO *) GlobalAlloc(GPTR, sizeof(SOCK_INFO))) == NULL)
    {
        // Don't leak the next provider's socket; keep *lpErrno as our error.
        NextProcTable.lpWSPCloseSocket(NextProviderSocket, &CloseErr);
        *lpErrno = WSAENOBUFS;
        return INVALID_SOCKET;
    }

    SocketContext->ProviderSocket = NextProviderSocket;
    SocketContext->bClosing = FALSE;
    SocketContext->dwOutstandingAsync = 0;
    SocketContext->BytesRecv = 0;
    SocketContext->BytesSent = 0;

    // The context is a pointer: pass it as DWORD_PTR so it is not truncated on x64.
    NewSocket = MainUpCallTable.lpWPUCreateSocketHandle(gChainId, (DWORD_PTR) SocketContext, lpErrno);
    if (NewSocket == INVALID_SOCKET)
    {
        // *lpErrno already holds the upcall's error; clean up without touching it.
        NextProcTable.lpWSPCloseSocket(NextProviderSocket, &CloseErr);
        GlobalFree(SocketContext);
        return INVALID_SOCKET;
    }

    {
        WCHAR buffer[128];
        wsprintfW(buffer, L"Creating socket %d\n", (int) NewSocket);
        OutputDebugStringW(buffer);
    }
    return NewSocket;
}
// Counts the catalog entries whose protocol chain contains this layer
// (gLayerCatId). For each one it copies the WSAPROTOCOL_INFOW of the chain's
// last (base) entry into a newly GlobalAlloc'd gBaseInfo array. Each chain
// contributes at most one entry, matching the count, so the array cannot
// overflow. Sets gLayerCount and gBaseInfo; returns FALSE if allocation fails.
static BOOL LspBuildBaseProviderTable(LPWSAPROTOCOL_INFOW ProtocolInfo, INT TotalProtocols)
{
    INT x, y, z, idx;
    DWORD BaseId;

    gLayerCount = 0;
    for (x = 0; x < TotalProtocols; x++)
    {
        for (y = 0; y < ProtocolInfo[x].ProtocolChain.ChainLen; y++)
        {
            if (ProtocolInfo[x].ProtocolChain.ChainEntries[y] == gLayerCatId)
            {
                gLayerCount++;
                break;
            }
        }
    }

    gBaseInfo = (LPWSAPROTOCOL_INFOW) GlobalAlloc(GPTR, sizeof(WSAPROTOCOL_INFOW) * gLayerCount);
    if (!gBaseInfo)
        return FALSE;

    idx = 0;
    for (x = 0; x < TotalProtocols && idx < gLayerCount; x++)
    {
        LPWSAPROTOCOLCHAIN Chain = &ProtocolInfo[x].ProtocolChain;

        for (y = 0; y < Chain->ChainLen; y++)
        {
            if (Chain->ChainEntries[y] != gLayerCatId)
                continue;

            // Our layer is in this chain; its last entry is the base provider.
            BaseId = Chain->ChainEntries[Chain->ChainLen - 1];
            for (z = 0; z < TotalProtocols; z++)
            {
                if (ProtocolInfo[z].dwCatalogEntryId == BaseId)
                {
                    memcpy(&gBaseInfo[idx++], &ProtocolInfo[z], sizeof(WSAPROTOCOL_INFOW));
                    OutputDebugStringW(gBaseInfo[idx - 1].szProtocol);
                    OutputDebugStringW(L"\n");
                    break;
                }
            }
            // One entry per chain, exactly as in the counting pass above.
            break;
        }
    }
    return TRUE;
}

// Finds the provider that follows this layer in lpProtocolInfo's chain. It then
// writes that provider's DLL path, with environment variables expanded, into
// LibraryPath (MAX_PATH WCHARs). Returns FALSE when one of these holds:
//   * the layer is absent from the chain, or is its last entry;
//   * the next catalog entry is unknown;
//   * the path lookup fails, or expansion fails or would truncate.
static BOOL LspGetNextProviderPath(LPWSAPROTOCOL_INFOW lpProtocolInfo,
    LPWSAPROTOCOL_INFOW ProtocolInfo, INT TotalProtocols, WCHAR *LibraryPath)
{
    WCHAR ProviderPath[MAX_PATH];
    INT ProviderPathLen = MAX_PATH;
    INT Error, i, j;
    DWORD NextProviderCatId = 0;
    BOOL Found = FALSE;
    DWORD Expanded;

    // j + 1 must stay inside the chain, or we would read a stale entry.
    for (j = 0; j + 1 < lpProtocolInfo->ProtocolChain.ChainLen; j++)
    {
        if (lpProtocolInfo->ProtocolChain.ChainEntries[j] == gLayerCatId)
        {
            NextProviderCatId = lpProtocolInfo->ProtocolChain.ChainEntries[j + 1];
            Found = TRUE;
            break;
        }
    }
    if (!Found)
        return FALSE;

    for (i = 0; i < TotalProtocols; i++)
    {
        if (ProtocolInfo[i].dwCatalogEntryId == NextProviderCatId)
            break;
    }
    if (i == TotalProtocols)
        return FALSE;

    if (WSCGetProviderPath(&ProtocolInfo[i].ProviderId, ProviderPath, &ProviderPathLen, &Error) == SOCKET_ERROR)
        return FALSE;

    // Returns the required size in characters; 0 is failure and > MAX_PATH
    // means LibraryPath was too small.
    Expanded = ExpandEnvironmentStringsW(ProviderPath, LibraryPath, MAX_PATH);
    return (Expanded != 0 && Expanded <= MAX_PATH);
}

// WSPStartup: entry point Ws2_32 calls to initialize this layered provider.
//
// On the first call (gEntryCount == 0) it:
//   1. enumerates the catalog (GetProviders, defined elsewhere) and finds this
//      layer's catalog ID by matching ProviderGuid;
//   2. records gChainId and builds the gBaseInfo table of base providers;
//   3. loads the provider that follows this layer in lpProtocolInfo's chain
//      and calls its WSPStartup;
//   4. saves that provider's procedure table in NextProcTable and replaces the
//      entries in lpProcTable with this layer's hooks.
// Later calls copy the saved WSPDATA and hooked procedure table into the
// caller's buffers. gCriticalSection is held for the whole call and released
// on every return path. On failure, anything acquired by this call is released
// and gEntryCount is not incremented. That covers gBaseInfo and the loaded
// library.
//
// Returns 0 on success, WSAEPROVIDERFAILEDINIT / WSAENOBUFS on local failures,
// or the next provider's WSPStartup error code.
int WSPAPI WSPStartup(
    WORD wVersion,
    LPWSPDATA lpWSPData,
    LPWSAPROTOCOL_INFOW lpProtocolInfo,
    WSPUPCALLTABLE UpCallTable,
    LPWSPPROC_TABLE lpProcTable)
{
    // Persistent copies: the caller's lpWSPData / lpProcTable buffers are not
    // guaranteed to outlive this call, so gWSPData / gProcTable point here.
    static WSPDATA SavedWSPData;
    static WSPPROC_TABLE SavedProcTable;

    INT ReturnCode = 0;
    INT TotalProtocols = 0;
    INT i;
    WCHAR LibraryPath[MAX_PATH];
    LPWSAPROTOCOL_INFOW ProtocolInfo = NULL;
    LPWSPSTARTUP WSPStartupFunc = NULL;

    EnterCriticalSection(&gCriticalSection);
    MainUpCallTable = UpCallTable;

    if (gEntryCount)
    {
        // Already initialized: fill the caller's out-parameters. (Assigning to
        // the parameter pointers themselves has no effect on the caller.)
        memcpy(lpWSPData, gWSPData, sizeof(WSPDATA));
        memcpy(lpProcTable, gProcTable, sizeof(WSPPROC_TABLE));
        gEntryCount++;
        LeaveCriticalSection(&gCriticalSection);
        return 0;
    }

    OutputDebugStringW(L"Layered Service Provider\n");

    // Get all protocol information in the catalog.
    // NOTE(review): ProtocolInfo is never released here; confirm how
    // GetProviders allocates it and free it with the matching routine.
    if ((ProtocolInfo = GetProviders(&TotalProtocols)) == NULL)
    {
        ReturnCode = WSAEPROVIDERFAILEDINIT;
        goto Fail;
    }

    // Find this layer's own catalog ID entry.
    for (i = 0; i < TotalProtocols; i++)
    {
        if (memcmp(&ProtocolInfo[i].ProviderId, &ProviderGuid, sizeof(GUID)) == 0)
        {
            gLayerCatId = ProtocolInfo[i].dwCatalogEntryId;
            break;
        }
    }

    // Save our protocol chain's catalog ID entry.
    gChainId = lpProtocolInfo->dwCatalogEntryId;

    if (!LspBuildBaseProviderTable(ProtocolInfo, TotalProtocols))
    {
        ReturnCode = WSAENOBUFS;
        goto Fail;
    }

    if (!LspGetNextProviderPath(lpProtocolInfo, ProtocolInfo, TotalProtocols, LibraryPath))
    {
        ReturnCode = WSAEPROVIDERFAILEDINIT;
        goto FailFreeBaseInfo;
    }

    if ((hProvider = LoadLibraryW(LibraryPath)) == NULL)
    {
        ReturnCode = WSAEPROVIDERFAILEDINIT;
        goto FailFreeBaseInfo;
    }

    if ((WSPStartupFunc = (LPWSPSTARTUP) GetProcAddress(hProvider, "WSPStartup")) == NULL)
    {
        ReturnCode = WSAEPROVIDERFAILEDINIT;
        goto FailUnload;
    }

    // If the next provider fails, its procedure table cannot be trusted, so
    // do not copy or hook it.
    ReturnCode = (*WSPStartupFunc)(wVersion, lpWSPData, lpProtocolInfo, UpCallTable, lpProcTable);
    if (ReturnCode != 0)
        goto FailUnload;

    // Save the next provider's procedure table.
    memcpy(&NextProcTable, lpProcTable, sizeof(WSPPROC_TABLE));

    // Remap service provider functions to this layer's hooks.
    lpProcTable->lpWSPAccept = WSPAccept;
    lpProcTable->lpWSPAddressToString = WSPAddressToString;
    lpProcTable->lpWSPAsyncSelect = WSPAsyncSelect;
    lpProcTable->lpWSPBind = WSPBind;
    lpProcTable->lpWSPCancelBlockingCall = WSPCancelBlockingCall;
    lpProcTable->lpWSPCleanup = WSPCleanup;
    lpProcTable->lpWSPCloseSocket = WSPCloseSocket;
    lpProcTable->lpWSPConnect = WSPConnect;
    lpProcTable->lpWSPDuplicateSocket = WSPDuplicateSocket;
    lpProcTable->lpWSPEnumNetworkEvents = WSPEnumNetworkEvents;
    lpProcTable->lpWSPEventSelect = WSPEventSelect;
    lpProcTable->lpWSPGetOverlappedResult = WSPGetOverlappedResult;
    lpProcTable->lpWSPGetPeerName = WSPGetPeerName;
    lpProcTable->lpWSPGetSockOpt = WSPGetSockOpt;
    lpProcTable->lpWSPGetSockName = WSPGetSockName;
    lpProcTable->lpWSPGetQOSByName = WSPGetQOSByName;
    lpProcTable->lpWSPIoctl = WSPIoctl;
    lpProcTable->lpWSPJoinLeaf = WSPJoinLeaf;
    lpProcTable->lpWSPListen = WSPListen;
    lpProcTable->lpWSPRecv = WSPRecv;
    lpProcTable->lpWSPRecvDisconnect = WSPRecvDisconnect;
    lpProcTable->lpWSPRecvFrom = WSPRecvFrom;
    lpProcTable->lpWSPSelect = WSPSelect;
    lpProcTable->lpWSPSend = WSPSend;
    lpProcTable->lpWSPSendDisconnect = WSPSendDisconnect;
    lpProcTable->lpWSPSendTo = WSPSendTo;
    lpProcTable->lpWSPSetSockOpt = WSPSetSockOpt;
    lpProcTable->lpWSPShutdown = WSPShutdown;
    lpProcTable->lpWSPSocket = WSPSocket;
    lpProcTable->lpWSPStringToAddress = WSPStringToAddress;

    // Keep our own copies for later WSPStartup calls.
    memcpy(&SavedWSPData, lpWSPData, sizeof(WSPDATA));
    memcpy(&SavedProcTable, lpProcTable, sizeof(WSPPROC_TABLE));
    gWSPData = &SavedWSPData;
    gProcTable = &SavedProcTable;

    gEntryCount++;
    LeaveCriticalSection(&gCriticalSection);
    return 0;

FailUnload:
    FreeLibrary(hProvider);
    hProvider = NULL;
FailFreeBaseInfo:
    GlobalFree(gBaseInfo);
    gBaseInfo = NULL;
    gLayerCount = 0;
Fail:
    LeaveCriticalSection(&gCriticalSection);
    return ReturnCode;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -