⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 filter.c

📁 非常实用的基于TDI驱动开发的应用程序过滤防火墙的例子
💻 C
字号:
// -*- mode: C++; tab-width: 4; indent-tabs-mode: nil -*- (for GNU Emacs)
//
// $Id: filter.c,v 1.5 2002/12/03 13:32:17 dev Exp $

/*
 * Filtering related routines
 */

#include <ntddk.h>
#include <tdikrnl.h>
#include "sock.h"

#include "filter.h"
#include "memtrack.h"
#include "pid_pname.h"
#include "tdi_fw.h"

// size of cyclic queue for logging (number of slots allocated for g_queue.data).
// log_request() treats the queue as full when (head + 1) % size == tail, so at
// most REQUEST_QUEUE_SIZE - 1 requests are buffered at any time.
#define REQUEST_QUEUE_SIZE	1024

/* rules chains (main (first entry) and process-related) */
/*
 * Each chain is a singly linked list of heap-allocated rules (linked through
 * flt_rule.next) plus an optional process name copied in by set_chain_pname().
 * After filter_init() every access in this file goes through `guard`.
 * NOTE(review): the chain used for a given pid comes from
 * pid_pname_get_context(); chain 0 is the fallback chosen by set_pid_pname().
 */
static struct {
	struct {
		struct		flt_rule *head;		// first rule, NULL when the chain is empty
		struct		flt_rule *tail;		// last rule; add_flt_rule() appends here
		char		*pname;				// name of process
	} chain[MAX_CHAINS_COUNT];
	KSPIN_LOCK	guard;
} g_rules;

/* logging request queue */
/*
 * Ring buffer of REQUEST_QUEUE_SIZE requests filled by log_request() and
 * drained by get_request(). `guard` protects data, head and tail. Each slot's
 * pname (if non-NULL) is a heap copy owned by the queue.
 * event_handle/event are created by filter_init_2(), released by
 * filter_free_2(), and log_request() signals `event` after queueing.
 */
static struct {
	struct		flt_request *data;
	KSPIN_LOCK	guard;
	ULONG		head;	/* write to head */
	ULONG		tail;	/* read from tail */
	HANDLE		event_handle;
	PKEVENT		event;
} g_queue;

// init
/*
 * One-time initialization of the filtering module.
 *
 * Calls pid_pname_init(), empties every rules chain (no rules, no process
 * name) and allocates a zero-filled logging queue of REQUEST_QUEUE_SIZE
 * entries.
 *
 * Returns STATUS_SUCCESS, or STATUS_INSUFFICIENT_RESOURCES if the queue
 * buffer cannot be allocated.
 * NOTE(review): on that failure path pid_pname_init() is not undone here;
 * confirm the caller still invokes filter_free()/pid_pname_free().
 */
NTSTATUS
filter_init(void)
{
	int i;

	pid_pname_init();

	/* rules chain */

	KeInitializeSpinLock(&g_rules.guard);
	for (i = 0; i < MAX_CHAINS_COUNT; i++) {
		g_rules.chain[i].head = g_rules.chain[i].tail = NULL;
		g_rules.chain[i].pname = NULL;
	}

	/* request queue */

	KeInitializeSpinLock(&g_queue.guard);

	g_queue.data = (struct flt_request *)malloc_np(sizeof(struct flt_request) * REQUEST_QUEUE_SIZE);
	if (g_queue.data == NULL) {
		KdPrint(("[tdi_fw] filter_init: malloc_np!\n"));
		return STATUS_INSUFFICIENT_RESOURCES;
	}

	// zero every slot so all pname pointers start out NULL (filter_free and
	// get_request rely on that to know what to free)
	memset(g_queue.data, 0, sizeof(struct flt_request) * REQUEST_QUEUE_SIZE);

	g_queue.head = g_queue.tail = 0;

	return STATUS_SUCCESS;
}

// init for user part starting
NTSTATUS
filter_init_2(void)
{
	NTSTATUS status;

	if (g_queue.event_handle == NULL) {
		UNICODE_STRING str;
		OBJECT_ATTRIBUTES oa;

		RtlInitUnicodeString(&str, L"\\BaseNamedObjects\\tdi_fw_request");
		InitializeObjectAttributes(&oa, &str, 0, NULL, NULL);

		status = ZwCreateEvent(&g_queue.event_handle, EVENT_ALL_ACCESS, &oa, SynchronizationEvent, FALSE);
		if (status != STATUS_SUCCESS) {
			KdPrint(("[tdi_fw] filter_init_2: ZwCreateEvent: 0x%x\n", status));
			return status;
		}

	}

	if (g_queue.event == NULL) {
		status = ObReferenceObjectByHandle(g_queue.event_handle, EVENT_ALL_ACCESS, NULL, KernelMode,
			&g_queue.event, NULL);
		if (status != STATUS_SUCCESS) {
			KdPrint(("[tdi_fw] filter_init_2: ObReferenceObjectByHandle: 0x%x\n", status));
			return status;
		}
	}

	return STATUS_SUCCESS;
}

// cleanup for user part
/*
 * Undoes filter_init_2(): drops the kernel reference to the request event
 * and closes its handle. Each part is skipped if it was never obtained, so
 * calling this repeatedly is harmless.
 * NOTE(review): log_request() reads g_queue.event without a lock; confirm
 * this cannot run concurrently with logging.
 */
void
filter_free_2(void)
{
	// test and dereference the event object itself, not the address of the
	// global variable that holds the pointer
	if (g_queue.event != NULL) {
		ObDereferenceObject(g_queue.event);
		g_queue.event = NULL;
	}
	if (g_queue.event_handle != NULL) {
		ZwClose(g_queue.event_handle);
		g_queue.event_handle = NULL;
	}
}

// free
void
filter_free(void)
{
	KIRQL irql;
	struct plist_entry *ple;
	int i;

	// clear all chains
	for (i = 0; i < MAX_CHAINS_COUNT; i++)
		clear_flt_chain(i);

	/* clear request queue */
	KeAcquireSpinLock(&g_queue.guard, &irql);
	for (i = 0; i < REQUEST_QUEUE_SIZE; i++)
		if (g_queue.data[i].pname != NULL)
			free(g_queue.data[i].pname);
	free(g_queue.data);
	KeReleaseSpinLock(&g_queue.guard, irql);

	pid_pname_free();
}

// quick filter
int
quick_filter(struct flt_request *request, struct flt_rule *rule)
{
    const struct sockaddr_in *from, *to;
	struct flt_rule *r;
	struct plist_entry *ple;
	KIRQL irql;
	int result;

	// not IP
    if (request->addr.len != sizeof(struct sockaddr_in) ||
        request->addr.from.sa_family != AF_INET ||
        request->addr.to.sa_family != AF_INET)
    {
		KdPrint(("[tdi_fw] quick_filter: not ip addr!\n"));
        return FILTER_DENY;
    }

    from = (const struct sockaddr_in *)&request->addr.from;
    to = (const struct sockaddr_in *)&request->addr.to;

	// default behavior
	result = FILTER_ALLOW;
	if (rule != NULL) {
		memset(rule, 0, sizeof(*rule));
		rule->result = result;
	}

	// quick filter
	KeAcquireSpinLock(&g_rules.guard, &irql);

	// go through rules
	for (r = g_rules.chain[pid_pname_get_context(request->pid)].head; r != NULL; r = r->next)
		// Can anybody understand it?
		if (r->proto == request->proto &&
			r->direction == request->direction &&
			(r->addr_from & r->mask_from) == (from->sin_addr.s_addr & r->mask_from) &&
			(r->addr_to & r->mask_to) == (to->sin_addr.s_addr & r->mask_to) &&
			(r->port_from == 0 || ((r->port2_from == 0) ? (r->port_from == from->sin_port) :
				(ntohs(from->sin_port) >= ntohs(r->port_from) && ntohs(from->sin_port) <= ntohs(r->port2_from)))) &&
			(r->port_to == 0 || ((r->port2_to == 0) ? (r->port_to == to->sin_port) :
				(ntohs(to->sin_port) >= ntohs(r->port_to) && ntohs(to->sin_port) <= ntohs(r->port2_to)))))
		{
			result = r->result;
			KdPrint(("[tdi_fw] quick_filter: found rule with result: %d\n", result));
			
			if (rule != NULL) {
				memcpy(rule, r, sizeof(*rule));
				
				rule->next = NULL;	// useless field
			}

			break;
		}


	KeReleaseSpinLock(&g_rules.guard, irql);

	request->result = result;
	return result;
}

// write request to request queue
/*
 * Appends a copy of `request` to the logging queue and signals the user-mode
 * event (if one was set up by filter_init_2()).
 *
 * The process name for request->pid is resolved with pid_pname_resolve()
 * and, if found, a heap copy is stored in the queued entry (owned by the
 * queue from then on).
 *
 * When the queue is full, the oldest entry is dropped and
 * request->log_skipped is set to that entry's log_skipped + 1; otherwise it
 * is set to 0. Note that `request` itself is modified.
 *
 * Returns FALSE without logging when no control application is attached
 * (g_got_control == 0), TRUE otherwise.
 */
BOOLEAN
log_request(struct flt_request *request)
{
	KIRQL irql;
	ULONG next_head;
	char pname_buf[256], *pname;

	if (g_got_control == 0)		// don't log - no control app
		return FALSE;

	// try to get process name; done before taking the queue lock so the
	// lock is held only for the ring-buffer update itself
	pname = NULL;
	if (pid_pname_resolve(request->pid, pname_buf, sizeof(pname_buf))) {
		KdPrint(("[tdi_fw] log_request: pid:%u; pname:%s\n",
			request->pid, pname_buf));

		// ala strdup()
		pname = (char *)malloc_np(strlen(pname_buf) + 1);
		if (pname != NULL)
			strcpy(pname, pname_buf);
		else
			KdPrint(("[tdi_fw] log_request: malloc_np!\n"));
	}

	KeAcquireSpinLock(&g_queue.guard, &irql);

	next_head = (g_queue.head + 1) % REQUEST_QUEUE_SIZE;

	if (next_head == g_queue.tail) {
		// queue overflow: reject one entry from tail
		struct flt_request *dropped = &g_queue.data[g_queue.tail];

		KdPrint(("[tdi_fw] log_request: queue overflow!\n"));

		request->log_skipped = dropped->log_skipped + 1;

		// the dropped entry owns its name copy; release it or it leaks
		if (dropped->pname != NULL) {
			free(dropped->pname);
			dropped->pname = NULL;
		}

		g_queue.tail = (g_queue.tail + 1) % REQUEST_QUEUE_SIZE;
	} else
		request->log_skipped = 0;

	memcpy(&g_queue.data[g_queue.head], request, sizeof(struct flt_request));
	g_queue.data[g_queue.head].pname = pname;
	g_queue.head = next_head;

	KeReleaseSpinLock(&g_queue.guard, irql);

	// signal to user app
	// NOTE(review): g_queue.event is read without a lock; confirm
	// filter_free_2() cannot run concurrently with logging.
	if (g_queue.event != NULL)
		KeSetEvent(g_queue.event, IO_NO_INCREMENT, FALSE);

	return TRUE;
}

// read requests from log queue to buffer
/*
 * Moves as many queued requests as fit into `buf` (capacity `buf_size`).
 *
 * Each record is a struct flt_request, immediately followed by the
 * NUL-terminated process name when the entry has one; in that case the
 * copied header's struct_size is increased by the name length (including
 * the NUL). Copied entries are removed from the queue and their name
 * buffers freed. An entry that does not fit stays queued for the next call.
 *
 * Returns the number of bytes written (0 if buf_size cannot hold even a
 * bare header, or if the queue is empty).
 */
ULONG
get_request(char *buf, ULONG buf_size)
{
	ULONG result = 0;
	KIRQL irql;

	// sanity check
	if (buf_size < sizeof(struct flt_request))
		return 0;

	KeAcquireSpinLock(&g_queue.guard, &irql);

	while (g_queue.head != g_queue.tail) {
		struct flt_request *entry = &g_queue.data[g_queue.tail];
		struct flt_request hdr;
		ULONG pname_size, rec_size;

		pname_size = (entry->pname != NULL) ? (ULONG)strlen(entry->pname) + 1 : 0;
		rec_size = sizeof(struct flt_request) + pname_size;

		// next record does not fit: stop here. Must break (not return) so
		// the queue lock below is released.
		if (buf_size < rec_size)
			break;

		// build the header in an aligned local: after the first record,
		// `buf` is generally not aligned for struct flt_request
		memcpy(&hdr, entry, sizeof(hdr));
		if (entry->pname != NULL)
			hdr.struct_size += pname_size;
		memcpy(buf, &hdr, sizeof(hdr));

		if (entry->pname != NULL) {
			// copy the name including its terminating NUL
			memcpy(buf + sizeof(struct flt_request), entry->pname, pname_size);

			free(entry->pname);
			entry->pname = NULL;
		}

		result += rec_size;
		buf += rec_size;
		buf_size -= rec_size;

		g_queue.tail = (g_queue.tail + 1) % REQUEST_QUEUE_SIZE;
	}

	KdPrint(("[tdi_fw] get_request: copied %u bytes\n", result));

	KeReleaseSpinLock(&g_queue.guard, irql);
	return result;
}

// add rule to rules chain
/*
 * Appends a heap copy of `rule` to the end of rules chain `chain` (its
 * `next` field is ignored and reset).
 *
 * Returns STATUS_INVALID_PARAMETER_1 if `chain` is outside
 * [0, MAX_CHAINS_COUNT), STATUS_INSUFFICIENT_RESOURCES if the copy cannot be
 * allocated, STATUS_SUCCESS otherwise.
 */
NTSTATUS
add_flt_rule(int chain, const struct flt_rule *rule)
{
	struct flt_rule *new_rule;
	KIRQL irql;

	// sanity check: reject any index outside the chain array
	if (chain < 0 || chain >= MAX_CHAINS_COUNT)
		return STATUS_INVALID_PARAMETER_1;

	// allocate and fill the copy before taking the lock
	new_rule = (struct flt_rule *)malloc_np(sizeof(struct flt_rule));
	if (new_rule == NULL) {
		KdPrint(("[tdi_fw] add_flt_rule: malloc_np\n"));
		return STATUS_INSUFFICIENT_RESOURCES;
	}

	memcpy(new_rule, rule, sizeof(*new_rule));
	new_rule->next = NULL;

	// append
	KeAcquireSpinLock(&g_rules.guard, &irql);

	if (g_rules.chain[chain].tail == NULL) {
		g_rules.chain[chain].head = new_rule;
		g_rules.chain[chain].tail = new_rule;
	} else {
		g_rules.chain[chain].tail->next = new_rule;
		g_rules.chain[chain].tail = new_rule;
	}

	KeReleaseSpinLock(&g_rules.guard, irql);
	return STATUS_SUCCESS;
}

// clear rules chain
/*
 * Removes and frees every rule of chain `chain`, leaving it empty. The
 * chain's process name (set_chain_pname) is not touched.
 *
 * Returns STATUS_INVALID_PARAMETER_1 if `chain` is outside
 * [0, MAX_CHAINS_COUNT), STATUS_SUCCESS otherwise.
 */
NTSTATUS
clear_flt_chain(int chain)
{
	struct flt_rule *rule;
	KIRQL irql;

	// sanity check: reject any index outside the chain array
	if (chain < 0 || chain >= MAX_CHAINS_COUNT)
		return STATUS_INVALID_PARAMETER_1;

	/* rules chain: detach the whole list under the lock */
	KeAcquireSpinLock(&g_rules.guard, &irql);

	rule = g_rules.chain[chain].head;
	g_rules.chain[chain].head = NULL;
	g_rules.chain[chain].tail = NULL;

	KeReleaseSpinLock(&g_rules.guard, irql);

	// the detached list is no longer reachable through g_rules, so it can be
	// freed without holding the lock
	while (rule != NULL) {
		struct flt_rule *next = rule->next;
		free(rule);
		rule = next;
	}

	return STATUS_SUCCESS;
}

// set process name for chain
NTSTATUS
set_chain_pname(int chain, char *pname)
{
	KIRQL irql;
	NTSTATUS status;

	// sanity check
	if (chain < 0 || chain >= MAX_CHAINS_COUNT)
		return STATUS_INVALID_PARAMETER_1;

	KdPrint(("[tdi_fw] set_chain_pname: setting name %s for chain %d\n", pname, chain));

	KeAcquireSpinLock(&g_rules.guard, &irql);

	if (g_rules.chain[chain].pname != NULL)
		free(g_rules.chain[chain].pname);

	g_rules.chain[chain].pname = (char *)malloc_np(strlen(pname) + 1);
	if (g_rules.chain[chain].pname != NULL) {
		// copy pname
		strcpy(g_rules.chain[chain].pname, pname);
		status = STATUS_SUCCESS;
	} else
		status = STATUS_INSUFFICIENT_RESOURCES;

	KeReleaseSpinLock(&g_rules.guard, irql);
	return status;
}

// set result of process name by pid resolving
/*
 * Records that process `pid` is named `pname` and binds it to a rules chain:
 * the first chain whose process name equals `pname` (case-insensitive), or
 * chain 0 (the main chain) when none matches.
 *
 * Returns whatever pid_pname_set() returns.
 */
NTSTATUS
set_pid_pname(ULONG pid, char *pname)
{
	KIRQL old_irql;
	int idx;
	int found_chain = 0;	/* main chain unless a process chain matches */

	KdPrint(("[tdi_fw] set_pid_pname: setting pname %s for pid %u\n", pname, pid));

	KeAcquireSpinLock(&g_rules.guard, &old_irql);

	for (idx = 0; idx < MAX_CHAINS_COUNT; idx++) {
		const char *chain_name = g_rules.chain[idx].pname;

		if (chain_name == NULL || _stricmp(pname, chain_name) != 0)
			continue;

		KdPrint(("[tdi_fw] set_pid_pname: found chain %d\n", idx));
		found_chain = idx;
		break;
	}

	KeReleaseSpinLock(&g_rules.guard, old_irql);

	return pid_pname_set(pid, pname, found_chain);
}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -