📄 drvfltip.cpp
字号:
// DrvFltIp.cpp文件
extern "C"
{
#include <ntddk.h>
#include <ntddndis.h>
#include <pfhook.h>
}
#include <stdio.h>
#include "DrvFltIp.h"
#include "internal.h"
/////////////////////////////////////////////////////////////////////////////////////////////
// ---- Routines implemented later in this file ----
NTSTATUS DispatchCreateClose(PDEVICE_OBJECT pDevObj, PIRP pIrp);
NTSTATUS DispatchIoctl(PDEVICE_OBJECT pDevObj, PIRP pIrp);
void DriverUnload(PDRIVER_OBJECT pDriverObj);

NTSTATUS AddFilterToList(CIPFilter* pFilter);
void ClearFilterList();

NTSTATUS SetFilterFunction(PacketFilterExtensionPtr filterFun);
PF_FORWARD_ACTION FilterPackets(unsigned char* PacketHeader,
	unsigned char* Packet,
	unsigned int PacketLength,
	unsigned int RecvInterfaceIndex,
	unsigned int SendInterfaceIndex,
	IPAddr RecvLinkNextHop,
	IPAddr SendLinkNextHop);

void inet_ntoa_krnl(unsigned long addr, char* out);
void ntohs(unsigned short port, unsigned short* out);

// ---- Global state ----
// Head of the singly linked filter-rule list; AddFilterToList prepends, so the
// newest rule comes first.
struct CFilterList* g_pHeader = NULL;

// Caller-registered array (set by ADD_INFORMATION_ADDRESS) that receives packet
// summaries, the number of entries filled so far, and the array's capacity.
PacketInfo* g_pPacket = NULL;
int g_nPacketCount = 0;
int g_nPacketMaxCount = 10;

// ---- Names ----
// Internal device name and the symbolic link exposed to user mode.
#define DEVICE_NAME L"\\Device\\devDrvFltIp"
#define LINK_NAME L"\\??\\DrvFltIp"
///////////////////////////////////////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////////////////////////////////////
// Driver entry point.
// Registers the dispatch routines, creates the control device and its symbolic link,
// adds a default pass-all filter rule and then hooks the IP filter driver.
// Returns the failing status if the device or the symbolic link cannot be created.
// A failure to add the default rule or to install the hook is only logged: the driver
// still loads (as before), and user mode can retry the hook via START_IP_HOOK.
extern "C" NTSTATUS DriverEntry(PDRIVER_OBJECT pDriverObj, PUNICODE_STRING pRegistryString)
{
	NTSTATUS status = STATUS_SUCCESS;

	// Register the dispatch routines.
	pDriverObj->MajorFunction[IRP_MJ_CREATE] = DispatchCreateClose;
	pDriverObj->MajorFunction[IRP_MJ_CLOSE] = DispatchCreateClose;
	pDriverObj->MajorFunction[IRP_MJ_DEVICE_CONTROL] = DispatchIoctl;
	pDriverObj->DriverUnload = DriverUnload;

	// Create the control device object.
	UNICODE_STRING ustrDevName;
	RtlInitUnicodeString(&ustrDevName, DEVICE_NAME);
	PDEVICE_OBJECT pDevObj;
	status = IoCreateDevice(pDriverObj,
		0,
		&ustrDevName,
		FILE_DEVICE_DRVFLTIP,
		0,
		FALSE,
		&pDevObj);
	if (!NT_SUCCESS(status))
	{
		return status;
	}

	// Create the symbolic link user mode opens; undo the device on failure.
	UNICODE_STRING ustrLinkName;
	RtlInitUnicodeString(&ustrLinkName, LINK_NAME);
	status = IoCreateSymbolicLink(&ustrLinkName, &ustrDevName);
	if (!NT_SUCCESS(status))
	{
		IoDeleteDevice(pDevObj);
		return status;
	}
	DbgPrint("DriverEntry例程初始化完毕!\n");

	// Add the default rule BEFORE installing the hook, so the callback never walks the
	// rule list while it is being modified. All-zero protocol/IP/port fields act as
	// wildcards in FilterPackets, so this rule forwards every packet no other rule claims.
	CIPFilter filter;
	memset(&filter, 0, sizeof(CIPFilter));
	filter.bDrop = FALSE;
	NTSTATUS ruleStatus = AddFilterToList(&filter);
	NTSTATUS hookStatus = SetFilterFunction(FilterPackets);

	if (NT_SUCCESS(hookStatus) && NT_SUCCESS(ruleStatus))
	{
		DbgPrint("加载回调函数成功!\n添加过滤规则成功!\n");
	}
	else
	{
		// Report what actually failed instead of claiming success.
		DbgPrint("DrvFltIp: SetFilterFunction status 0x%08X, AddFilterToList status 0x%08X\n",
			hookStatus, ruleStatus);
	}
	return STATUS_SUCCESS;
}
// Unload routine: removes the filter hook, frees the rule list, then deletes the
// symbolic link and the control device created in DriverEntry.
// The hook status is deliberately not checked; an unload routine returns void and
// cannot fail.
void DriverUnload(PDRIVER_OBJECT pDriverObj)
{
	// Unhook first so the callback is no longer registered before the rules go away.
	(void)SetFilterFunction(NULL);
	ClearFilterList();

	// Remove the user-visible name, then the device object itself.
	UNICODE_STRING linkName;
	RtlInitUnicodeString(&linkName, LINK_NAME);
	IoDeleteSymbolicLink(&linkName);

	PDEVICE_OBJECT pControlDevice = pDriverObj->DeviceObject;
	IoDeleteDevice(pControlDevice);

	DbgPrint("DriverUnload例程卸载完毕!\n");
}
// Handles IRP_MJ_CREATE and IRP_MJ_CLOSE: opening or closing the control device
// always succeeds and transfers no data.
NTSTATUS DispatchCreateClose(PDEVICE_OBJECT pDevObj, PIRP pIrp)
{
	pIrp->IoStatus.Status = STATUS_SUCCESS;
	// No bytes are transferred; set Information explicitly instead of completing the
	// IRP with whatever value the field happened to hold.
	pIrp->IoStatus.Information = 0;
	IoCompleteRequest(pIrp, IO_NO_INCREMENT);
	return STATUS_SUCCESS;
}
// IRP_MJ_DEVICE_CONTROL dispatch routine.
//   START_IP_HOOK           - installs FilterPackets as the IP filter-hook callback.
//   STOP_IP_HOOK            - removes the callback.
//   ADD_FILTER              - appends the CIPFilter in the system buffer to the rule list
//                             (input length must be exactly sizeof(CIPFilter)).
//   CLEAR_FILTER            - frees every rule.
//   ADD_INFORMATION_ADDRESS - registers the caller's buffer that FilterPackets fills with
//                             up to g_nPacketMaxCount PacketInfo records.
// Any other code fails with STATUS_INVALID_DEVICE_REQUEST. The IRP is always completed
// here with Information = 0, and its status is returned.
NTSTATUS DispatchIoctl(PDEVICE_OBJECT pDevObj, PIRP pIrp)
{
	NTSTATUS status = STATUS_SUCCESS;
	PIO_STACK_LOCATION pIrpStack = IoGetCurrentIrpStackLocation(pIrp);
	ULONG uIoControlCode = pIrpStack->Parameters.DeviceIoControl.IoControlCode;

	// System buffer, presumably for METHOD_BUFFERED codes such as ADD_FILTER.
	PVOID pIoBuffer = pIrp->AssociatedIrp.SystemBuffer;
	// Raw caller pointer; only meaningful if ADD_INFORMATION_ADDRESS is defined with
	// METHOD_NEITHER — NOTE(review): confirm in DrvFltIp.h.
	PVOID pInputBuffer = pIrpStack->Parameters.DeviceIoControl.Type3InputBuffer;
	ULONG uInSize = pIrpStack->Parameters.DeviceIoControl.InputBufferLength;

	switch (uIoControlCode)
	{
	case START_IP_HOOK:
		status = SetFilterFunction(FilterPackets);
		break;
	case STOP_IP_HOOK:
		status = SetFilterFunction(NULL);
		break;
	case ADD_FILTER:
		if (uInSize == sizeof(CIPFilter))
			status = AddFilterToList((CIPFilter*)pIoBuffer);
		else
			status = STATUS_INVALID_DEVICE_REQUEST;
		break;
	case CLEAR_FILTER:
		ClearFilterList();
		break;
	case ADD_INFORMATION_ADDRESS:
		__try
		{
			// FilterPackets WRITES up to g_nPacketMaxCount records into this buffer, so
			// validate that whole range as writable user memory, not one pointer's worth.
			ProbeForWrite(pInputBuffer,
				sizeof(PacketInfo) * g_nPacketMaxCount,
				TYPE_ALIGNMENT(PacketInfo));
			// A newly registered buffer starts empty.
			g_nPacketCount = 0;
			g_pPacket = (PacketInfo*)pInputBuffer;
			DbgPrint("传输封包信息成功!g_pPacket: %p\n", g_pPacket);
			// NOTE(review): the buffer is a user-mode address valid only in the caller's
			// process; the filter callback may run in any process context and at raised
			// IRQL. Locking it with an MDL and mapping a system address would be safer.
		}
		__except (EXCEPTION_EXECUTE_HANDLER)
		{
			// Report the probe failure to the caller instead of completing with success.
			status = GetExceptionCode();
			DbgPrint("设置封包数据抛出异常!\n异常代码:%d\n", status);
		}
		break;
	default:
		status = STATUS_INVALID_DEVICE_REQUEST;
		break;
	}

	pIrp->IoStatus.Status = status;
	pIrp->IoStatus.Information = 0;
	IoCompleteRequest(pIrp, IO_NO_INCREMENT);
	return status;
}
///////////////////////////////////////////////////////////////////
// Filter rule list

// Pool tag for rule-list allocations; shows up as "FlIp" in poolmon / Driver Verifier.
#define DRVFLTIP_POOL_TAG 'pIlF'

// Prepends a copy of *pFilter to the global rule list (g_pHeader), so the most
// recently added rule is evaluated first.
// Returns STATUS_INVALID_PARAMETER for a NULL rule, STATUS_INSUFFICIENT_RESOURCES if
// the node cannot be allocated, STATUS_SUCCESS otherwise.
NTSTATUS AddFilterToList(CIPFilter* pFilter)
{
	if (pFilter == NULL)
		return STATUS_INVALID_PARAMETER;

	// Non-paged: the list is walked from the filter-hook callback, which may run at
	// raised IRQL. ExFreePool in ClearFilterList releases tagged blocks as well.
	CFilterList* pNew = (CFilterList*)ExAllocatePoolWithTag(NonPagedPool,
		sizeof(CFilterList),
		DRVFLTIP_POOL_TAG);
	if (pNew == NULL)
		return STATUS_INSUFFICIENT_RESOURCES;

	RtlCopyMemory(&pNew->ipf, pFilter, sizeof(CIPFilter));

	// Link the node fully before publishing it as the new head.
	pNew->pNext = g_pHeader;
	g_pHeader = pNew;
	return STATUS_SUCCESS;
}
// 清除过滤列表
void ClearFilterList()
{
CFilterList* pNext;
// 释放过滤列表占用的所有内存
while(g_pHeader != NULL)
{
pNext = g_pHeader->pNext;
// 释放内存
ExFreePool(g_pHeader);
g_pHeader = pNext;
}
}
////////////////////////////////////////////////////
//过滤函数
// 过滤钩子回调函数
PF_FORWARD_ACTION FilterPackets(
unsigned char *PacketHeader,
unsigned char *Packet,
unsigned int PacketLength,
unsigned int RecvInterfaceIndex,
unsigned int SendInterfaceIndex,
IPAddr RecvLinkNextHop,
IPAddr SendLinkNextHop)
{
// 提取IP头
IPHeader* pIPHdr = (IPHeader*)PacketHeader;
TCPHeader* pTCPHdr = (TCPHeader*)Packet;
UDPHeader* pUDPHdr = (UDPHeader*)Packet;
if(g_pPacket!=0)
{
///////////////////////////////////////////////////////////////////////
//设置info
PacketInfo info;
info.protocol = pIPHdr->ipProtocol; // 使用的协议
info.sourceIP = pIPHdr->ipSource; // 源IP地址
info.destinationIP = pIPHdr->ipDestination; // 目标IP地址
if(pIPHdr->ipProtocol==6) // TCP协议
{
info.sourcePort = pTCPHdr->sourcePort; // 源端口号
info.destinationPort = pTCPHdr->destinationPort; // 目的端口号
}
else if(pIPHdr->ipProtocol==17) // UDP协议
{
info.sourcePort = pUDPHdr->sourcePort; // 源端口号
info.destinationPort = pUDPHdr->destinationPort; // 目的端口号
}
info.bRecv = (RecvInterfaceIndex==INVALID_PF_IF_INDEX)?FALSE:TRUE;
//////////////////////////////////////////////////////////////////////////
//复制内存工作
if(g_nPacketCount<g_nPacketMaxCount-1)
{
RtlCopyMemory(g_pPacket+g_nPacketCount, &info, sizeof(PacketInfo));
g_nPacketCount++;
}
else
{
RtlCopyMemory(g_pPacket, g_pPacket+g_nPacketMaxCount-1,
sizeof(PacketInfo)*(g_nPacketMaxCount-1));
RtlCopyMemory(g_pPacket+g_nPacketCount, &info, sizeof(PacketInfo));
}
////////////////////////////////////////////////////////////////////////////
//打印工作
char s[50], d[50];
unsigned short source, des;
inet_ntoa_krnl(info.sourceIP, s);
inet_ntoa_krnl(info.destinationIP, d);
ntohs(info.sourcePort, &source);
ntohs(info.destinationPort, &des);
DbgPrint("[*]传输数据成功!地址: 0x%.8X\n", g_pPacket+g_nPacketCount);
DbgPrint((info.bRecv)?"[接受]":"[发送]");
DbgPrint("Protocol: %u, Source: %s:%u, Destination: %s:%u",
info.protocol, s, d, source ,des);
}
if(pIPHdr->ipProtocol == 6) // 是TCP协议?
{
// 我们接受所有已经建立连接的TCP封包
if(!(((TCPHeader*)Packet)->flags & 0x02))
{
return PF_FORWARD;
}
}
// 与过滤规则相比较,决定采取的行动
CFilterList* pList = g_pHeader;
while(pList != NULL)
{
// 比较协议
if(pList->ipf.protocol == 0 || pList->ipf.protocol == pIPHdr->ipProtocol)
{
// 查看源IP地址
if(pList->ipf.sourceIP != 0 &&
(pList->ipf.sourceIP & pList->ipf.sourceMask) != pIPHdr->ipSource)
{
pList = pList->pNext;
continue;
}
// 查看目标IP地址
if(pList->ipf.destinationIP != 0 &&
(pList->ipf.destinationIP & pList->ipf.destinationMask) != pIPHdr->ipDestination)
{
pList = pList->pNext;
continue;
}
// 如果是TCP封包,查看端口号
if(pIPHdr->ipProtocol == 6)
{
//TCPHeader* pTCPHdr = (TCPHeader*)Packet;
if(pList->ipf.sourcePort == 0 || pList->ipf.sourcePort == pTCPHdr->sourcePort)
{
if(pList->ipf.destinationPort == 0
|| pList->ipf.destinationPort == pTCPHdr->destinationPort)
{
// 现在决定如何处理这个封包
if(pList->ipf.bDrop)
return PF_DROP;
else
return PF_FORWARD;
}
}
}
// 如果是UDP封包,查看端口号
else if(pIPHdr->ipProtocol == 17)
{
//UDPHeader* pUDPHdr = (UDPHeader*)Packet;
if(pList->ipf.sourcePort == 0 || pList->ipf.sourcePort == pUDPHdr->sourcePort)
{
if(pList->ipf.destinationPort == 0
|| pList->ipf.destinationPort == pUDPHdr->destinationPort)
{
// 现在决定如何处理这个封包
if(pList->ipf.bDrop)
return PF_DROP;
else
return PF_FORWARD;
}
}
}
else
{
// 对于其它封包,我们直接处理
if(pList->ipf.bDrop)
return PF_DROP;
else
return PF_FORWARD;
}
}
// 比较下一个规则
pList = pList->pNext;
}
// 我们接受所有没有注册的封包
return PF_FORWARD;
}
// Installs (filterFun != NULL) or removes (filterFun == NULL) the packet filter-hook
// callback by sending IOCTL_PF_SET_EXTENSION_POINTER to \Device\IPFILTERDRIVER and
// waiting for the request to finish.
// Returns the failure from IoGetDeviceObjectPointer, STATUS_INSUFFICIENT_RESOURCES if
// the IRP cannot be built, or the status with which the IP filter driver completed
// the request. The file-object reference is released on every path after it is taken.
NTSTATUS SetFilterFunction(PacketFilterExtensionPtr filterFun)
{
	NTSTATUS status;
	PDEVICE_OBJECT pDeviceObj = NULL;
	PFILE_OBJECT pFileObj = NULL;

	// Get the IP filter driver's device object (this also references a file object).
	UNICODE_STRING ustrFilterDriver;
	RtlInitUnicodeString(&ustrFilterDriver, L"\\Device\\IPFILTERDRIVER");
	status = IoGetDeviceObjectPointer(&ustrFilterDriver, FILE_ALL_ACCESS, &pFileObj, &pDeviceObj);
	if (!NT_SUCCESS(status))
	{
		return status;
	}

	PF_SET_EXTENSION_HOOK_INFO filterData;
	filterData.ExtensionPointer = filterFun;

	// The I/O manager signals this event when the IRP completes.
	KEVENT event;
	KeInitializeEvent(&event, NotificationEvent, FALSE);

	IO_STATUS_BLOCK ioStatus;
	PIRP pIrp = IoBuildDeviceIoControlRequest(IOCTL_PF_SET_EXTENSION_POINTER,
		pDeviceObj,
		(PVOID)&filterData,
		sizeof(PF_SET_EXTENSION_HOOK_INFO),
		NULL,
		0,
		FALSE,
		&event,
		&ioStatus);
	if (pIrp == NULL)
	{
		// Fall through to the cleanup below so the file object is not leaked.
		status = STATUS_INSUFFICIENT_RESOURCES;
	}
	else
	{
		status = IoCallDriver(pDeviceObj, pIrp);
		if (status == STATUS_PENDING)
		{
			// KernelMode wait keeps this stack (event, ioStatus) resident.
			KeWaitForSingleObject(&event, Executive, KernelMode, FALSE, NULL);
			status = ioStatus.Status;
		}
	}

	ObDereferenceObject(pFileObj);
	return status;
}
// Formats an IPv4 address as dotted-decimal text into out.
// addr holds the four octets exactly as they lie in memory in the IP header; on a
// little-endian host the first octet is therefore the least significant byte, and it
// is printed first.
// out must point to at least 50 writable bytes: they are zeroed before formatting.
// Plain unsigned shift-and-mask avoids the undefined shift-by-32 and the negative
// octets that signed shifts on a 32-bit long produced.
void inet_ntoa_krnl(const unsigned long addr, char* out)
{
	memset(out, 0, 50);
	sprintf(out, "%lu.%lu.%lu.%lu",
		addr & 0xFFUL,
		(addr >> 8) & 0xFFUL,
		(addr >> 16) & 0xFFUL,
		(addr >> 24) & 0xFFUL);
}
// Swaps the two bytes of a 16-bit value and stores the result in *out
// (network -> host order on a little-endian machine).
void ntohs(const unsigned short port, unsigned short* out)
{
	unsigned short upperMovedDown = (unsigned short)(port >> 8);
	unsigned short lowerMovedUp = (unsigned short)(port << 8);
	*out = (unsigned short)(upperMovedDown | lowerMovedUp);
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -