⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 zinflate.cpp

📁 加密函数库:包括多种加密解密算法,数字签名,散列算法
💻 CPP
📖 第 1 页 / 共 2 页
字号:
// zinflate.cpp - written and placed in the public domain by Wei Dai

// This is a complete reimplementation of the DEFLATE decompression algorithm.
// It should not be affected by any security vulnerabilities in the zlib 
// compression library. In particular it is not affected by the double free bug
// (http://www.kb.cert.org/vuls/id/368819).

#include "pch.h"
#include "zinflate.h"

NAMESPACE_BEGIN(CryptoPP)

// Heterogeneous comparator used with std::upper_bound over the sorted
// code table: compares a bare normalized code against a CodeInfo entry's code.
struct CodeLessThan
{
	inline bool operator()(const CryptoPP::HuffmanDecoder::code_t lhs, const CryptoPP::HuffmanDecoder::CodeInfo &rhs)
	{
		const CryptoPP::HuffmanDecoder::code_t rhsCode = rhs.code;
		return lhs < rhsCode;
	}
};

inline bool LowFirstBitReader::FillBuffer(unsigned int length)
{
	while (m_bitsBuffered < length)
	{
		byte b;
		if (!m_store.Get(b))
			return false;
		m_buffer |= (unsigned long)b << m_bitsBuffered;
		m_bitsBuffered += 8;
	}
	assert(m_bitsBuffered <= sizeof(unsigned long)*8);
	return true;
}

// Returns the next `length` bits (lowest buffered bits) without consuming them.
// Sufficient input is only checked by assert; callers must guarantee it.
inline unsigned long LowFirstBitReader::PeekBits(unsigned int length)
{
	const bool filled = FillBuffer(length);
	assert(filled);
	(void)filled;	// only inspected by assert
	const unsigned long lowMask = ((unsigned long)1 << length) - 1;
	return m_buffer & lowMask;
}

// Discards the `length` lowest buffered bits; they must already be buffered.
inline void LowFirstBitReader::SkipBits(unsigned int length)
{
	assert(length <= m_bitsBuffered);
	m_bitsBuffered -= length;
	m_buffer >>= length;
}

// Reads and consumes the next `length` bits, least-significant bit first.
inline unsigned long LowFirstBitReader::GetBits(unsigned int length)
{
	const unsigned long bits = PeekBits(length);
	SkipBits(length);
	return bits;
}

// Converts a code held in its low `codeBits` bits (code MSB at bit codeBits-1)
// into "normalized" form, with the code's MSB at the top bit of code_t.
inline HuffmanDecoder::code_t HuffmanDecoder::NormalizeCode(HuffmanDecoder::code_t code, unsigned int codeBits)
{
	const unsigned int shift = MAX_CODE_BITS - codeBits;
	return code << shift;
}

// Builds the decoding tables from per-symbol code lengths: codeBits[i] is the
// length in bits of symbol i's code, and 0 means symbol i is unused. Codes are
// assigned canonically (shorter codes first, increasing within a length).
//
// Throws Err on:
//   - an empty code ("null code");
//   - a length above MAX_CODE_BITS ("code length exceeds maximum");
//   - lengths that need more than 2^maxLen codes ("codes oversubscribed");
//   - lengths that leave codes unused ("codes incomplete").
// An incomplete code whose longest length is 1 bit is accepted; presumably
// this covers DEFLATE's single one-bit distance code.
void HuffmanDecoder::Initialize(const unsigned int *codeBits, unsigned int nCodes)
{
	// the Huffman codes are represented in 3 ways in this code:
	//
	// 1. most significant code bit (i.e. top of code tree) in the least significant bit position
	// 2. most significant code bit (i.e. top of code tree) in the most significant bit position
	// 3. most significant code bit (i.e. top of code tree) in n-th least significant bit position,
	//    where n is the maximum code length for this code tree
	//
	// (1) is the way the codes come in from the deflate stream
	// (2) is used to sort codes so they can be binary searched
	// (3) is used in this function to compute codes from code lengths
	//
	// a code in representation (2) is called "normalized" here
	// The BitReverse() function is used to convert between (1) and (2)
	// The NormalizeCode() function is used to convert from (3) to (2)

	if (nCodes == 0)
		throw Err("null code");

	m_maxCodeBits = *std::max_element(codeBits, codeBits+nCodes);

	if (m_maxCodeBits > MAX_CODE_BITS)
		throw Err("code length exceeds maximum");

	if (m_maxCodeBits == 0)
		throw Err("null code");

	// count number of codes of each length
	SecBlockWithHint<unsigned int, 15+1> blCount(m_maxCodeBits+1);
	std::fill(blCount.begin(), blCount.end(), 0);
	unsigned int i;
	for (i=0; i<nCodes; i++)
		blCount[codeBits[i]]++;

	// compute the starting code of each length
	code_t code = 0;
	SecBlockWithHint<code_t, 15+1> nextCode(m_maxCodeBits+1);
	nextCode[1] = 0;
	for (i=2; i<=m_maxCodeBits; i++)
	{
		// compute this while checking for overflow: code = (code + blCount[i-1]) << 1
		if (code > code + blCount[i-1])
			throw Err("codes oversubscribed");
		code += blCount[i-1];
		if (code > (code << 1))
			throw Err("codes oversubscribed");
		code <<= 1;
		nextCode[i] = code;
	}

	// Kraft check. `code` is now the first code of length m_maxCodeBits, so the
	// codes of that length occupy [code, code + blCount[m_maxCodeBits]). They
	// must fit in, and (unless m_maxCodeBits == 1) exactly fill, the
	// 2^m_maxCodeBits code space.
	// This is computed in 64 bits for two reasons:
	//   - m_maxCodeBits may be as large as MAX_CODE_BITS, so `1 << m_maxCodeBits`
	//     in int would overflow;
	//   - the form `(1 << n) - blCount[n]` wraps when blCount[n] > 2^n, which
	//     let oversubscribed codes slip through.
	const unsigned long long codeSpace = 1ULL << m_maxCodeBits;
	const unsigned long long codesUsed = static_cast<unsigned long long>(code) + blCount[m_maxCodeBits];
	if (codesUsed > codeSpace)
		throw Err("codes oversubscribed");
	else if (m_maxCodeBits != 1 && codesUsed < codeSpace)
		throw Err("codes incomplete");

	// compute a vector of <code, length, value> triples sorted by code
	m_codeToValue.resize(nCodes - blCount[0]);
	unsigned int j=0;
	for (i=0; i<nCodes; i++) 
	{
		unsigned int len = codeBits[i];
		if (len != 0)
		{
			// nextCode[len] is in representation (3); store it normalized (2)
			code = NormalizeCode(nextCode[len]++, len);
			m_codeToValue[j].code = code;
			m_codeToValue[j].len = len;
			m_codeToValue[j].value = i;
			j++;
		}
	}
	std::sort(m_codeToValue.begin(), m_codeToValue.end());

	// initialize the decoding cache (indexed by the first m_cacheBits input bits)
	m_cacheBits = STDMIN(9U, m_maxCodeBits);
	m_cacheMask = (1 << m_cacheBits) - 1;
	m_normalizedCacheMask = NormalizeCode(m_cacheMask, m_cacheBits);
	assert(m_normalizedCacheMask == BitReverse(m_cacheMask));

	const unsigned int cacheSize = 1U << m_cacheBits;
	if (m_cache.size() != cacheSize)
		m_cache.resize(cacheSize);

	// type 0 marks an entry as not yet filled (see FillCacheEntry)
	for (i=0; i<m_cache.size(); i++)
		m_cache[i].type = 0;
}

// Lazily populates one decoding-cache entry. The slot is identified by the
// top m_cacheBits bits of `normalizedCode` (the "prefix"); lower bits are ignored.
// Initialize marks every entry as type 0 (unfilled). This function sets one of:
//   1 - the code containing this prefix is no longer than m_cacheBits bits, so
//       the symbol value and code length are stored directly;
//   2 - the first and last candidate codes for this prefix have the same
//       length; entry.begin points at the first, and Decode indexes from it
//       using the bits that follow the prefix;
//   3 - the candidates have differing lengths; [entry.begin, entry.end) brackets
//       them and Decode binary-searches that range.
void HuffmanDecoder::FillCacheEntry(LookupEntry &entry, code_t normalizedCode) const
{
	// keep only the prefix bits that select this cache slot
	normalizedCode &= m_normalizedCacheMask;
	// last table entry whose code is <= the prefix (the code the prefix falls in)
	const CodeInfo &codeInfo = *(std::upper_bound(m_codeToValue.begin(), m_codeToValue.end(), normalizedCode, CodeLessThan())-1);
	if (codeInfo.len <= m_cacheBits)
	{
		entry.type = 1;
		entry.value = codeInfo.value;
		entry.len = codeInfo.len;
	}
	else
	{
		entry.begin = &codeInfo;
		// last table entry whose code is <= the largest code sharing this prefix
		// (prefix with all lower bits set)
		const CodeInfo *last = & *(std::upper_bound(m_codeToValue.begin(), m_codeToValue.end(), normalizedCode + ~m_normalizedCacheMask, CodeLessThan())-1);
		if (codeInfo.len == last->len)
		{
			entry.type = 2;
			entry.len = codeInfo.len;
		}
		else
		{
			entry.type = 3;
			entry.end = last+1;
		}
	}
}

// Decodes one symbol from `code`, whose low bits hold the next input bits in
// stream order (representation 1, least-significant bit first).
// Sets `value` to the decoded symbol and returns the length in bits of the
// matched code. It does not check that that many bits were actually
// available; the caller must do so.
inline unsigned int HuffmanDecoder::Decode(code_t code, /* out */ value_t &value) const
{
	assert(m_codeToValue.size() > 0);
	// cache slot chosen by the first m_cacheBits input bits; filled on first use
	LookupEntry &entry = m_cache[code & m_cacheMask];

	// normalizedCode is only computed, and only read, when the entry is not a
	// direct hit (type 1); it is deliberately left uninitialized otherwise
	code_t normalizedCode;
	if (entry.type != 1)
		normalizedCode = BitReverse(code);

	if (entry.type == 0)
		FillCacheEntry(entry, normalizedCode);

	if (entry.type == 1)
	{
		value = entry.value;
		return entry.len;
	}
	else
	{
		// type 2: index from entry.begin by the (entry.len - m_cacheBits) bits that
		// follow the cached prefix; type 3: binary search within [begin, end)
		const CodeInfo &codeInfo = (entry.type == 2)
			? entry.begin[(normalizedCode << m_cacheBits) >> (MAX_CODE_BITS - (entry.len - m_cacheBits))]
			: *(std::upper_bound(entry.begin, entry.end, normalizedCode, CodeLessThan())-1);
		value = codeInfo.value;
		return codeInfo.len;
	}
}

// Decodes one symbol from `reader`. `value` is written in either case.
// If the matched code is longer than the bits the reader could buffer,
// returns false without consuming any bits. Otherwise consumes the code's
// bits and returns true.
bool HuffmanDecoder::Decode(LowFirstBitReader &reader, value_t &value) const
{
	// Best effort: near the end of input fewer than m_maxCodeBits bits may be
	// available, so the fill result is intentionally ignored here.
	reader.FillBuffer(m_maxCodeBits);
	const unsigned int matchedBits = Decode(reader.PeekBuffer(), value);
	const bool enoughBits = matchedBits <= reader.BitsBuffered();
	if (enoughBits)
		reader.SkipBits(matchedBits);
	return enoughBits;
}

// *************************************************************

// Constructs an inflator that passes output to `attachment`.
// `propagation` is forwarded to the base filter and `repeat` is stored in m_repeat.
// The bit reader m_reader is bound to m_inQueue, which receives the input in Put2.
Inflator::Inflator(BufferedTransformation *attachment, bool repeat, int propagation)
	: AutoSignaling<Filter>(attachment, propagation),
	  m_state(PRE_STREAM),
	  m_repeat(repeat),
	  m_decodersInitializedWithFixedCodes(false),
	  m_reader(m_inQueue)
{
}

// Returns the inflator to the PRE_STREAM state. Updates m_repeat from the
// "Repeat" parameter, and discards both queued input and any bits already
// buffered in the bit reader.
void Inflator::IsolatedInitialize(const NameValuePairs &parameters)
{
	m_state = PRE_STREAM;
	parameters.GetValue("Repeat", m_repeat);

	// drop leftover input: first the reader's buffered bits, then the byte queue
	m_reader.SkipBits(m_reader.BitsBuffered());
	m_inQueue.Clear();
}

// Appends one decompressed byte to the circular window m_window.
// When the window becomes full:
//   - the unflushed part [m_lastFlush, end) is passed to ProcessDecompressedData;
//   - writing wraps to the start of the window.
// m_maxDistance grows by one per byte, capped at the window size.
inline void Inflator::OutputByte(byte b)
{
	m_window[m_current] = b;
	if (++m_current == m_window.size())
	{
		ProcessDecompressedData(m_window + m_lastFlush, m_window.size() - m_lastFlush);
		m_current = 0;
		m_lastFlush = 0;
	}

	if (m_maxDistance < m_window.size())
		++m_maxDistance;
}

void Inflator::OutputString(const byte *string, unsigned int length)
{
	while (length--)
		OutputByte(*string++);
}

// Re-emits `length` bytes copied from `distance` bytes back in the circular
// output window (a length/distance back-reference).
// Throws BadBlockErr if `distance` exceeds m_maxDistance. m_maxDistance counts
// bytes output so far, capped at the window size.
void Inflator::OutputPast(unsigned int length, unsigned int distance)
{
	if (distance > m_maxDistance)
		throw BadBlockErr();
	// Locate the source position, wrapping to the end of the window when the
	// history lies before the current write position's start-of-window.
	// (When m_current == distance this yields start == m_window.size(), which the
	// wrap step below turns into 0.)
	unsigned int start;
	if (m_current > distance)
		start = m_current - distance;
	else
		start = m_current + m_window.size() - distance;

	// If the source runs past the end of the window, copy up to the end first
	// and continue from the start of the window.
	if (start + length > m_window.size())
	{
		for (; start < m_window.size(); start++, length--)
			OutputByte(m_window[start]);
		start = 0;
	}

	// Copy byte by byte when either:
	//   - the source reaches into the destination (recently written bytes are
	//     repeated), or
	//   - the write would reach the window end (OutputByte flushes and wraps).
	// Otherwise the ranges are disjoint and a single memcpy suffices; that path
	// bypasses OutputByte, so it updates m_current and m_maxDistance itself.
	if (start + length > m_current || m_current + length >= m_window.size())
	{
		while (length--)
			OutputByte(m_window[start++]);
	}
	else
	{
		memcpy(m_window + m_current, m_window + start, length);
		m_current += length;
		m_maxDistance = STDMIN((unsigned int)m_window.size(), m_maxDistance + length);
	}
}

// Accepts `length` bytes of compressed input and decompresses as much as it can.
// Only blocking operation is supported; a non-blocking call throws
// BlockingInputOnly. The input is made available to m_inQueue through a
// LazyPutter that lives for the whole call (presumably avoiding a copy;
// TODO confirm against LazyPutter).
// If messageEnd is nonzero, the stream must be in PRE_STREAM or AFTER_END
// after processing; otherwise UnexpectedEndErr is thrown.
// Always returns 0.
unsigned int Inflator::Put2(const byte *inString, unsigned int length, int messageEnd, bool blocking)
{
	if (!blocking)
		throw BlockingInputOnly("Inflator");

	LazyPutter inputPutter(m_inQueue, inString, length);
	const bool endOfMessage = (messageEnd != 0);
	ProcessInput(endOfMessage);

	const bool atStreamBoundary = (m_state == PRE_STREAM || m_state == AFTER_END);
	if (endOfMessage && !atStreamBoundary)
		throw UnexpectedEndErr();

	Output(0, NULL, 0, messageEnd, blocking);
	return 0;
}

bool Inflator::IsolatedFlush(bool hardFlush, bool blocking)

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -