📄 aricacm.c
字号:
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* $Copyright Issue
* ----------------
* Most of the arithmetic coding code is modified from
* I.H. Witten, R.M. Neal, and J.G. Cleary, "Arithmetic coding for data
* compression," Communications of the ACM, vol. 30, pp. 520-540, June 1987.
*
* The statistical table is implemented using a Binary Indexed Tree as described in
* P. M. Fenwick, "A new data structure for cumulative frequency tables,"
* Softw. Pract. Exper. 24, 3 (Mar. 1994), 327-336.
*
* The idea of having the arithmetic coder read/write bits directly came
* from G. Davis's code in the "Wavelet Compression Construction Kit". His code
* only works for writing; it has a synchronisation problem when reading.
* I fixed that with a virtual table approach.
*
* Mow-Song, Ng 2/9/2002
* msng@mmu.edu.my
* http://www.pesona.mmu.edu.my/~msng
*
* I do not claim copyright to the code, but if you use them or modify them,
* please drop me a mail.
*
*/
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
#include "aricacm.h"
//#define __F_SCALE_
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* Source context */
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
Context *ContextAlloc(void)
{
Context *context;
if((context=(Context *)malloc(sizeof(Context)))==NULL){
return NULL;
}
context->Tree=NULL;
context->nSymbols=0;
context->TotalFreq=0;
context->p2half=0;
context->MaxCount = 0;
return context;
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* (Re)build the frequency table of `context` for `nSymbols` symbols.
 *
 * Any previously allocated table is released first.  When `set` is 1 every
 * symbol starts with a count of 1; otherwise all counts start at 0.
 * `MaxCount` is the total-count threshold that ContextUpdate() compares
 * against before halving the counts.
 *
 * Returns 1 on success.  Returns 0 if the table cannot be allocated; in
 * that case `context` ITSELF has been freed and the caller must not use
 * (or free) it again.
 */
int ContextInitialize(Context *context, int nSymbols, int MaxCount, int set)
{
    int pow2 = 1;
    int sym;

    context->nSymbols  = nSymbols;
    context->TotalFreq = 0;

    /* Smallest power of two >= nSymbols (1 when nSymbols <= 1); half of it
     * is the starting step of the binary search in ContextGetSymbol(). */
    while (pow2 < context->nSymbols) {
        pow2 <<= 1;
    }
    context->p2half = pow2 >> 1;

    free(context->Tree);   /* free(NULL) is a no-op */
    context->Tree = (int *)calloc(context->nSymbols, sizeof(int));
    if (context->Tree == NULL) {
        free(context);
        return 0;
    }

    if (set == 1) {
        for (sym = 0; sym < context->nSymbols; sym++) {
            ContextPutValue(context, 1, sym);
        }
    }
    context->MaxCount = MaxCount;
    return 1;
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* Release a context and its frequency table.  NULL is accepted. */
void ContextDealloc(Context *context)
{
    if (context == NULL) {
        return;
    }
    free(context->Tree);
    free(context);
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* Cumulative frequency of symbols 0..ix inclusive; 0 when ix < 0.
 * Symbol 0's count lives directly in Tree[0]; for ix > 0 the Fenwick nodes
 * are summed by repeatedly clearing the lowest set bit of the index. */
int ContextGetCumul(Context *context, int ix)
{
    unsigned int total;
    int node;

    if (ix < 0) {
        return 0;
    }
    total = context->Tree[0];
    for (node = ix; node > 0; node &= node - 1) {
        total += context->Tree[node];
    }
    return total;
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* Add `val` (may be negative) to the count of symbol `ix` and to TotalFreq.
 * Symbol 0 is stored directly in Tree[0]; any other symbol updates every
 * Fenwick node that covers it, stepping ix by its lowest set bit
 * (2*ix - (ix & (ix-1)) == ix + lowbit(ix)) until past the table. */
void ContextPutValue(Context *context, int val, int ix)
{
    assert(ix >= 0);
    if (ix != 0) {
        int node;
        for (node = ix; node < context->nSymbols;
             node = 2*node - (node & (node - 1))) {
            context->Tree[node] += val;
        }
    } else {
        context->Tree[0] += val;
    }
    context->TotalFreq += val;
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* Individual (non-cumulative) count of symbol `ix`.
 * For ix > 0, Tree[ix] holds the sum over symbols (ix & (ix-1), ix];
 * subtracting the nodes met while walking from ix-1 down to ix & (ix-1)
 * leaves only symbol ix's own count.  Symbol 0 is read straight from
 * Tree[0]. */
int ContextGetProb(Context *context, int ix)
{
    int freq;
    int stop;
    int node;

    assert(ix >= 0 && ix < context->nSymbols);
    freq = context->Tree[ix];
    if (ix == 0) {
        return freq;
    }
    stop = ix & (ix - 1);
    for (node = ix - 1; node != stop; node &= node - 1) {
        freq -= context->Tree[node];
    }
    return freq;
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* Find the symbol s whose cumulative interval contains `cumFreq`, i.e.
 * ContextGetCumul(s-1) <= cumFreq < ContextGetCumul(s).
 *
 * Symbol 0 is checked directly against Tree[0].  For the rest, a Fenwick
 * binary search starting at step p2half finds the largest baseIx whose
 * cumulative count is <= cumFreq; the answer is baseIx + 1.
 *
 * NOTE(review): cumFreq is expected to be in [0, TotalFreq); larger values
 * are not rejected and can yield an index equal to nSymbols.
 */
int ContextGetSymbol(Context *context, int cumFreq)
{
    int baseIx = 0;
    int half;
    int target;

    if (cumFreq < context->Tree[0]) {
        return 0;
    }
    /* Remaining count to locate among symbols 1..nSymbols-1; the +1 turns
     * the "<= cumFreq" search into the strict ">" comparison below. */
    target = cumFreq + 1 - context->Tree[0];

    for (half = context->p2half; half > 0; half >>= 1) {
        int testIx = baseIx + half;
        /* Bounds check MUST come first: p2half comes from the next power of
         * two >= nSymbols, so testIx can run past the end of Tree, which
         * holds only nSymbols entries. */
        if (testIx < context->nSymbols && target > context->Tree[testIx]) {
            baseIx = testIx;
            target -= context->Tree[testIx];
        }
    }
    return baseIx + 1;
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* Halve every symbol's count.  The adjustment is (-count)/2, which
 * truncates toward zero, so each count becomes count - count/2 (a count of
 * 1 stays 1).
 * Symbols are visited from the highest index down.  ContextGetProb(i) reads
 * only tree nodes <= i, and ContextPutValue(i) writes only nodes >= i (or
 * Tree[0] for i == 0).  As a result, every count is read before any update
 * can disturb it. */
void ContextScaleDown(Context *context)
{
    int sym = context->nSymbols;
    while (sym-- > 0) {
        int freq = ContextGetProb(context, sym);
        ContextPutValue(context, -freq / 2, sym);
    }
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* Add `val` to symbol `ix`.  If the running total already exceeds MaxCount,
 * all counts are halved first.  The check happens before the increment. */
void ContextUpdate(Context *context, int val, int ix)
{
    int needRescale = context->TotalFreq > context->MaxCount;

    if (needRescale) {
#ifdef __F_SCALE_
        fprintf(stderr, "Scale down\n");
#endif
        ContextScaleDown(context);
    }
    ContextPutValue(context, val, ix);
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* Estimated cost, in bits, of coding `symbol` with the current counts:
 * -ONELOG2 * ln(freq / TotalFreq).  Presumably ONELOG2 is 1/ln(2), making
 * this -log2(p) — confirm against aricacm.h.
 * NOTE(review): a zero count or a zero TotalFreq is not guarded here. */
double ContextGetCost(Context *context, int symbol)
{
    double p = ((double)ContextGetProb(context, symbol))/context->TotalFreq;
    return -ONELOG2*log(p);
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* Debug dump to `fp`: one line per symbol with its count, its cumulative
 * count and the raw Fenwick node value, followed by the total. */
void ContextPrintTable(Context *context, FILE *fp)
{
    int sym = 0;

    fprintf(fp, "--------------------------------------------------------------------\n");
    while (sym < context->nSymbols) {
        int freq  = ContextGetProb(context, sym);
        int cumul = ContextGetCumul(context, sym);
        fprintf(fp, "index:%5d freq:%7d cumul freq:%8d stored:%7d\n",
                sym, freq, cumul, context->Tree[sym]);
        sym++;
    }
    fprintf(fp, "Total: %d\n", context->TotalFreq);
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* Arithmetic encoder */
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* Allocate an encoder bound to bit stream `bs`.  `bs` may be NULL, in
 * which case ArithEncoderBitPlusFollow() writes nothing and only counts
 * bits.  Coding state is set by ArithEncoderStart().  Returns NULL if the
 * allocation fails. */
ArithEncoder *ArithEncoderAlloc(BIT_FILE *bs)
{
    ArithEncoder *enc = (ArithEncoder *)malloc(sizeof *enc);

    if (enc != NULL) {
        enc->bs = bs;
    }
    return enc;
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* Free the encoder struct.  The bit stream is neither closed nor freed
 * here.  Passing NULL is harmless (free(NULL) is a no-op). */
void ArithEncoderDealloc(ArithEncoder *encoder)
{
    free(encoder);
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* Begin a new message: full interval [0, TopValue], no pending follow bits
 * and an empty output-bit counter. */
void ArithEncoderStart(ArithEncoder *encoder)
{
    encoder->BitsToFollow = 0;
    encoder->nBitsOutput  = 0;
    encoder->low  = 0;
    encoder->high = TopValue;
    encoder->range = encoder->high - encoder->low + 1;
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* Narrow the coding interval [low, high] to the sub-interval
 * [countLeft, count) out of countTotal, then renormalise as in Witten,
 * Neal & Cleary (1987):
 *  - interval entirely in the lower half  -> emit 0 and expand;
 *  - interval entirely in the upper half  -> emit 1, shift down, expand;
 *  - interval inside the middle half      -> defer a bit (BitsToFollow++),
 *                                            shift down by FirstQtr, expand;
 *  - otherwise stop.
 *
 * countLeft  : cumulative count of the symbols below the one being coded
 *              (BasicCoderEncode passes ContextGetCumul(symbol-1))
 * count      : cumulative count up to and including that symbol
 *              (BasicCoderEncode passes ContextGetCumul(symbol))
 * countTotal : total count of the model
 *
 * NOTE(review): encoder->range is used as stored and is not recomputed
 * here.  The caller is expected to set range = high - low + 1 first, as
 * BasicCoderEncode does.  The products range*count must fit in range's
 * type — confirm against the declarations in aricacm.h.
 */
void ArithEncoderEncode(ArithEncoder *encoder, int countLeft, int count, int countTotal)
{
encoder->high= encoder->low+(encoder->range*count)/countTotal - 1;
encoder->low = encoder->low+(encoder->range*countLeft)/countTotal;
for (;;){
/* Lower half: the next output bit is known to be 0. */
if (encoder->high<Half){
ArithEncoderBitPlusFollow(encoder, 0);
}
/* Upper half: the next output bit is 1; move the interval down. */
else if (encoder->low>=Half){
ArithEncoderBitPlusFollow(encoder, 1);
encoder->low-=Half;
encoder->high-=Half;
}
/* Straddles the midpoint within the middle half: the bit is not yet
   known, so record a pending opposite bit and recentre the interval. */
else if (encoder->low>=FirstQtr &&
encoder->high< ThirdQtr){
encoder->BitsToFollow++;
encoder->low-=FirstQtr;
encoder->high-=FirstQtr;
}
else{
break;
}
/* Double the interval; high gets a trailing 1 bit. */
encoder->low=2*encoder->low;
encoder->high=2*encoder->high + 1;
}
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* Emit `bit`, then flush the pending opposite bits (BitsToFollow) that
 * ArithEncoderEncode accumulated in its middle-half case.
 *
 * When the encoder has no bit stream (bs == NULL) nothing is written and
 * only nBitsOutput advances, so the coder can measure how many bits a
 * message would take.  Each bit is counted exactly once in both modes.
 * This keeps the counting-only total consistent with the stream case and
 * with the accounting in ArithEncoderNBitsOutput(); the old counting
 * branch incremented nBitsOutput twice per bit.
 */
void ArithEncoderBitPlusFollow(ArithEncoder *encoder, int bit)
{
    if (encoder->bs != NULL) {
        OutputBit(encoder->bs, bit);
    }
    encoder->nBitsOutput++;

    while (encoder->BitsToFollow > 0) {
        if (encoder->bs != NULL) {
            OutputBit(encoder->bs, !bit);
        }
        encoder->nBitsOutput++;
        encoder->BitsToFollow -= 1;
    }
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* Finish the message.  Emit one bit selecting the quarter that contains
 * the final interval (0 if low < FirstQtr, else 1), followed by its
 * opposite bit(s).  This outputs the previously pending bits plus two. */
void ArithEncoderDone(ArithEncoder *encoder)
{
    int lastBit = (encoder->low < FirstQtr) ? 0 : 1;

    encoder->BitsToFollow += 1;
    ArithEncoderBitPlusFollow(encoder, lastBit);
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
int ArithEncoderNBitsOutput(ArithEncoder *encoder)
{
    /* Bits emitted so far plus the bits still owed: the pending
     * BitsToFollow and the two bits that ArithEncoderDone() will add. */
    return 2 + encoder->BitsToFollow + encoder->nBitsOutput;
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* Arithmetic decoder */
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* Allocate a decoder that reads from bit stream `bs`.  Interval state is
 * set later by ArithDecoderStart().  Returns NULL if the allocation fails. */
ArithDecoder *ArithDecoderAlloc(BIT_FILE *bs)
{
    ArithDecoder *dec = (ArithDecoder *)malloc(sizeof *dec);

    if (dec == NULL) {
        return NULL;
    }
    dec->bs = bs;
    return dec;
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* Free the decoder struct.  The bit stream is neither closed nor freed
 * here.  Passing NULL is harmless (free(NULL) is a no-op). */
void ArithDecoderDealloc(ArithDecoder *decoder)
{
    free(decoder);
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* Prime the decoder.  Reset the interval to [0, TopValue] and shift the
 * first CodeValueBits bits of the stream into `value`.
 *
 * The interval and bit counter are initialised BEFORE any bit is read.
 * If the stream ends early (decoder->eof set), the decoder is therefore
 * still in a defined state: low, high and range are valid, and nBitsInput
 * holds the number of bits read successfully.  On a normal start
 * nBitsInput ends at CodeValueBits, as before.
 */
void ArithDecoderStart(ArithDecoder *decoder)
{
    int i;

    decoder->eof = 0;
    decoder->value = 0;
    decoder->low = 0;
    decoder->high = TopValue;
    decoder->range = decoder->high - decoder->low + 1;
    decoder->nBitsInput = 0;

    for (i = 0; i < CodeValueBits; i++) {
        decoder->value = 2*decoder->value + InputBit(decoder->bs, &decoder->eof);
        if (decoder->eof) {
            /* NOTE(review): value is left only partially filled (not
             * shifted out to the full CodeValueBits width). */
            return;
        }
        decoder->nBitsInput++;
    }
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* Decoder counterpart of ArithEncoderEncode().  Narrow [low, high] to the
 * sub-interval [countLeft, count) out of countTotal, then renormalise with
 * the same lower-half / upper-half / middle-half rules.  Instead of
 * emitting bits, `value` is shifted along with the interval, and one new
 * stream bit is shifted into it for every doubling.
 *
 * NOTE(review): decoder->range is used as stored and is not recomputed
 * here.  Presumably the caller sets range = high - low + 1 before each
 * call, mirroring BasicCoderEncode on the encoder side — confirm.
 */
void ArithDecoderDecode(ArithDecoder *decoder, int countLeft, int count, int countTotal)
{
decoder->high=decoder->low+(decoder->range*count)/countTotal-1;
decoder->low =decoder->low+(decoder->range*countLeft)/countTotal;
for (;;){
/* Lower half: nothing to subtract, just expand below. */
if (decoder->high<Half){
}
/* Upper half: move interval and value down by Half. */
else if (decoder->low>=Half){
decoder->value-=Half;
decoder->low-=Half;
decoder->high-=Half;
}
/* Middle half (underflow case): move everything down by FirstQtr. */
else if (decoder->low>=FirstQtr &&
decoder->high<ThirdQtr){
decoder->value-=FirstQtr;
decoder->low-=FirstQtr;
decoder->high-=FirstQtr;
}
else{
break;
}
/* Double the interval and shift the next stream bit into value. */
decoder->low = 2*decoder->low;
decoder->high= 2*decoder->high + 1;
decoder->value= 2*decoder->value + InputBit(decoder->bs, &decoder->eof);
/* NOTE(review): the original author flagged this spot as possibly
   problematic.  On end of stream the loop exits after low/high/value have
   been doubled, with the bit InputBit returned at EOF already folded into
   value.  nBitsInput is not incremented for that bit, and any remaining
   renormalisation steps for this symbol are skipped. */
if (decoder->eof){
break;
}
decoder->nBitsInput++;
}
return;
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* Number of bits consumed from the stream so far.  ArithDecoderStart()
 * counts the initial fill, and each successful bit read in
 * ArithDecoderDecode() adds one. */
int ArithDecoderNBitsInput(ArithDecoder *decoder)
{
    return decoder->nBitsInput;
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* BasicCoder - from G.Davis code */
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* Allocate a coder with its own adaptive context of `nSymbols` symbols.
 * Each symbol starts with a count of 1, and ContextUpdate() halves the
 * counts once the total exceeds MaxCount.  EndOfStreamSymbol is set to the
 * last symbol index (nSymbols - 1).
 *
 * Returns NULL if any allocation fails; nothing is leaked in that case.
 */
BasicCoder *BasicCoderAlloc(int nSymbols, int MaxCount)
{
    BasicCoder *coder;

    if ((coder = (BasicCoder *)malloc(sizeof(BasicCoder))) == NULL) {
        return NULL;
    }
    coder->nSymbols = nSymbols;

    coder->context = ContextAlloc();
    if (coder->context == NULL) {
        free(coder);
        return NULL;
    }
    /* ContextInitialize() frees the context itself when it fails.  Only
     * the coder struct is released here, and the dangling context pointer
     * is never dereferenced. */
    if (!ContextInitialize(coder->context, nSymbols, MaxCount, 1)) {
        free(coder);
        return NULL;
    }
    coder->EndOfStreamSymbol = coder->context->nSymbols - 1;
    return coder;
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* Free the coder together with the context allocated for it by
 * BasicCoderAlloc().  NULL is ignored. */
void BasicCoderDealloc(BasicCoder *coder)
{
    if (coder == NULL) {
        return;
    }
    ContextDealloc(coder->context);
    free(coder);
}
/*----------------------------------------------------------------------------*/
/*----------------------------------------------------------------------------*/
/* Code `symbol` with the coder's adaptive model and return its estimated
 * cost in bits, measured against the counts before this symbol is added.
 * With encoder == NULL only the cost is computed.  If `update` is true,
 * the symbol's count is incremented afterwards (possibly halving all counts
 * first; see ContextUpdate). */
double BasicCoderEncode(BasicCoder *coder, ArithEncoder *encoder,
                        int symbol, Boolean update)
{
    Context *ctx = coder->context;
    double bits = ContextGetCost(ctx, symbol);

    if (encoder != NULL) {
        int lowCount  = ContextGetCumul(ctx, symbol - 1);
        int highCount = ContextGetCumul(ctx, symbol);
        /* ArithEncoderEncode uses the stored range, so refresh it first. */
        encoder->range = encoder->high - encoder->low + 1;
        ArithEncoderEncode(encoder, lowCount, highCount, ctx->TotalFreq);
    }
    if (update) {
        ContextUpdate(ctx, 1, symbol);
    }
    return bits;
}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -