📄 zinflate.cpp
字号:
// zinflate.cpp - written and placed in the public domain by Wei Dai// This is a complete reimplementation of the DEFLATE decompression algorithm.// It should not be affected by any security vulnerabilities in the zlib // compression library. In particular it is not affected by the double free bug// (http://www.kb.cert.org/vuls/id/368819).#include "pch.h"#include "zinflate.h"NAMESPACE_BEGIN(CryptoPP)struct CodeLessThan{ inline bool operator()(CryptoPP::HuffmanDecoder::code_t lhs, const CryptoPP::HuffmanDecoder::CodeInfo &rhs) {return lhs < rhs.code;} // needed for MSVC .NET 2005 inline bool operator()(const CryptoPP::HuffmanDecoder::CodeInfo &lhs, const CryptoPP::HuffmanDecoder::CodeInfo &rhs) {return lhs.code < rhs.code;}};inline bool LowFirstBitReader::FillBuffer(unsigned int length){ while (m_bitsBuffered < length) { byte b; if (!m_store.Get(b)) return false; m_buffer |= (unsigned long)b << m_bitsBuffered; m_bitsBuffered += 8; } assert(m_bitsBuffered <= sizeof(unsigned long)*8); return true;}inline unsigned long LowFirstBitReader::PeekBits(unsigned int length){ bool result = FillBuffer(length); assert(result); return m_buffer & (((unsigned long)1 << length) - 1);}inline void LowFirstBitReader::SkipBits(unsigned int length){ assert(m_bitsBuffered >= length); m_buffer >>= length; m_bitsBuffered -= length;}inline unsigned long LowFirstBitReader::GetBits(unsigned int length){ unsigned long result = PeekBits(length); SkipBits(length); return result;}inline HuffmanDecoder::code_t HuffmanDecoder::NormalizeCode(HuffmanDecoder::code_t code, unsigned int codeBits){ return code << (MAX_CODE_BITS - codeBits);}void HuffmanDecoder::Initialize(const unsigned int *codeBits, unsigned int nCodes){ // the Huffman codes are represented in 3 ways in this code: // // 1. most significant code bit (i.e. top of code tree) in the least significant bit position // 2. most significant code bit (i.e. top of code tree) in the most significant bit position // 3. most significant code bit (i.e. top of code tree) in n-th least significant bit position, // where n is the maximum code length for this code tree // // (1) is the way the codes come in from the deflate stream // (2) is used to sort codes so they can be binary searched // (3) is used in this function to compute codes from code lengths // // a code in representation (2) is called "normalized" here // The BitReverse() function is used to convert between (1) and (2) // The NormalizeCode() function is used to convert from (3) to (2) if (nCodes == 0) throw Err("null code"); m_maxCodeBits = *std::max_element(codeBits, codeBits+nCodes); if (m_maxCodeBits > MAX_CODE_BITS) throw Err("code length exceeds maximum"); if (m_maxCodeBits == 0) throw Err("null code"); // count number of codes of each length SecBlockWithHint<unsigned int, 15+1> blCount(m_maxCodeBits+1); std::fill(blCount.begin(), blCount.end(), 0); unsigned int i; for (i=0; i<nCodes; i++) blCount[codeBits[i]]++; // compute the starting code of each length code_t code = 0; SecBlockWithHint<code_t, 15+1> nextCode(m_maxCodeBits+1); nextCode[1] = 0; for (i=2; i<=m_maxCodeBits; i++) { // compute this while checking for overflow: code = (code + blCount[i-1]) << 1 if (code > code + blCount[i-1]) throw Err("codes oversubscribed"); code += blCount[i-1]; if (code > (code << 1)) throw Err("codes oversubscribed"); code <<= 1; nextCode[i] = code; } if (code > (1 << m_maxCodeBits) - blCount[m_maxCodeBits]) throw Err("codes oversubscribed"); else if (m_maxCodeBits != 1 && code < (1 << m_maxCodeBits) - blCount[m_maxCodeBits]) throw Err("codes incomplete"); // compute a vector of <code, length, value> triples sorted by code m_codeToValue.resize(nCodes - blCount[0]); unsigned int j=0; for (i=0; i<nCodes; i++) { unsigned int len = codeBits[i]; if (len != 0) { code = NormalizeCode(nextCode[len]++, len); m_codeToValue[j].code = code; m_codeToValue[j].len = len; m_codeToValue[j].value = i; j++; } } std::sort(m_codeToValue.begin(), m_codeToValue.end()); // initialize the decoding cache m_cacheBits = STDMIN(9U, m_maxCodeBits); m_cacheMask = (1 << m_cacheBits) - 1; m_normalizedCacheMask = NormalizeCode(m_cacheMask, m_cacheBits); assert(m_normalizedCacheMask == BitReverse(m_cacheMask)); if (m_cache.size() != size_t(1) << m_cacheBits) m_cache.resize(1 << m_cacheBits); for (i=0; i<m_cache.size(); i++) m_cache[i].type = 0;}void HuffmanDecoder::FillCacheEntry(LookupEntry &entry, code_t normalizedCode) const{ normalizedCode &= m_normalizedCacheMask; const CodeInfo &codeInfo = *(std::upper_bound(m_codeToValue.begin(), m_codeToValue.end(), normalizedCode, CodeLessThan())-1); if (codeInfo.len <= m_cacheBits) { entry.type = 1; entry.value = codeInfo.value; entry.len = codeInfo.len; } else { entry.begin = &codeInfo; const CodeInfo *last = & *(std::upper_bound(m_codeToValue.begin(), m_codeToValue.end(), normalizedCode + ~m_normalizedCacheMask, CodeLessThan())-1); if (codeInfo.len == last->len) { entry.type = 2; entry.len = codeInfo.len; } else { entry.type = 3; entry.end = last+1; } }}inline unsigned int HuffmanDecoder::Decode(code_t code, /* out */ value_t &value) const{ assert(m_codeToValue.size() > 0); LookupEntry &entry = m_cache[code & m_cacheMask]; code_t normalizedCode; if (entry.type != 1) normalizedCode = BitReverse(code); if (entry.type == 0) FillCacheEntry(entry, normalizedCode); if (entry.type == 1) { value = entry.value; return entry.len; } else { const CodeInfo &codeInfo = (entry.type == 2) ? entry.begin[(normalizedCode << m_cacheBits) >> (MAX_CODE_BITS - (entry.len - m_cacheBits))] : *(std::upper_bound(entry.begin, entry.end, normalizedCode, CodeLessThan())-1); value = codeInfo.value; return codeInfo.len; }}bool HuffmanDecoder::Decode(LowFirstBitReader &reader, value_t &value) const{ reader.FillBuffer(m_maxCodeBits); unsigned int codeBits = Decode(reader.PeekBuffer(), value); if (codeBits > reader.BitsBuffered()) return false; reader.SkipBits(codeBits); return true;}// *************************************************************Inflator::Inflator(BufferedTransformation *attachment, bool repeat, int propagation) : AutoSignaling<Filter>(propagation) , m_state(PRE_STREAM), m_repeat(repeat), m_reader(m_inQueue){ Detach(attachment);}void Inflator::IsolatedInitialize(const NameValuePairs ¶meters){ m_state = PRE_STREAM; parameters.GetValue("Repeat", m_repeat); m_inQueue.Clear(); m_reader.SkipBits(m_reader.BitsBuffered());}void Inflator::OutputByte(byte b){ m_window[m_current++] = b; if (m_current == m_window.size()) { ProcessDecompressedData(m_window + m_lastFlush, m_window.size() - m_lastFlush); m_lastFlush = 0; m_current = 0; m_wrappedAround = true; }}void Inflator::OutputString(const byte *string, size_t length){ while (length) { size_t len = UnsignedMin(length, m_window.size() - m_current); memcpy(m_window + m_current, string, len); m_current += len; if (m_current == m_window.size()) { ProcessDecompressedData(m_window + m_lastFlush, m_window.size() - m_lastFlush); m_lastFlush = 0; m_current = 0; m_wrappedAround = true; } string += len; length -= len; } }void Inflator::OutputPast(unsigned int length, unsigned int distance){ size_t start; if (distance <= m_current) start = m_current - distance; else if (m_wrappedAround && distance <= m_window.size()) start = m_current + m_window.size() - distance; else throw BadBlockErr(); if (start + length > m_window.size()) { for (; start < m_window.size(); start++, length--) OutputByte(m_window[start]); start = 0; } if (start + length > m_current || m_current + length >= m_window.size()) { while (length--) OutputByte(m_window[start++]); } else { memcpy(m_window + m_current, m_window + start, length); m_current += length; }}size_t Inflator::Put2(const byte *inString, size_t length, int messageEnd, bool blocking){ if (!blocking) throw BlockingInputOnly("Inflator"); LazyPutter lp(m_inQueue, inString, length); ProcessInput(messageEnd != 0); if (messageEnd) if (!(m_state == PRE_STREAM || m_state == AFTER_END)) throw UnexpectedEndErr(); Output(0, NULL, 0, messageEnd, blocking); return 0;}bool Inflator::IsolatedFlush(bool hardFlush, bool blocking){ if (!blocking) throw BlockingInputOnly("Inflator"); if (hardFlush) ProcessInput(true);
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -