📄 paq8f.cpp
字号:
// err is scaled 16 bits (representing +- 1/2). w[i] is clamped to +- 32K
// and rounded. n is rounded up to a multiple of 8.
//#ifdef NOASM
// Nudge each weight w[k] by the rounded product of input t[k] and the
// prediction error err, saturating the result to the 16-bit range.
//   t   - inputs
//   w   - weights, updated in place
//   n   - number of inputs; rounded up to a multiple of 8, so both arrays
//         must hold that many elements
//   err - prediction error (the caller keeps it within 16 bits)
void train(short *t, short *w, int n, int err) {
  const int count=(n+7)&-8;  // round up to a multiple of 8
  for (int k=0; k<count; ++k) {
    // (t*err*2 >> 16) is the scaled product; +1 then >>1 rounds it
    const int delta=((t[k]*err*2>>16)+1)>>1;
    const int sum=w[k]+delta;
    w[k] = sum<-32768 ? -32768 : (sum>32767 ? 32767 : sum);
  }
}
//#else
//extern "C" void train(short *t, short *w, int n, int err); // in NASM
//#endif
// A Mixer combines up to N inputs into a single prediction by taking the
// dot product of the inputs with a weight vector selected by context, and
// trains the selected weights after each bit.
//
// Per-bit usage, as implied by the methods below:
//   update()        train the weights used for the last prediction (bit y)
//   add(x)          append an input (up to N times)
//   set(cx, range)  select a weight set (up to S times)
//   p()             return the prediction (12 bits, as produced by squash())
//
// When S>1 the constructor creates a second Mixer (mp) that combines the
// S per-context predictions into the final one.
class Mixer {
  const int N, M, S;    // max inputs (rounded up to a multiple of 8),
                        // max contexts, max context sets
  Array<short, 16> tx;  // N inputs from add()
  Array<short, 16> wx;  // N*M weights
  Array<int> cxt;       // S contexts
  int ncxt;             // number of contexts (0 to S)
  int base;             // offset of next context
  int nx;               // Number of inputs in tx, 0 to N
  Array<int> pr;        // last result (scaled 12 bits)
  Mixer* mp;            // points to a Mixer to combine results

  // The destructor deletes mp, so a memberwise copy would lead to a
  // double delete. Copying is disabled (declared, never defined).
  Mixer(const Mixer&);
  Mixer& operator=(const Mixer&);
public:
  Mixer(int n, int m, int s=1, int w=0);
  ~Mixer();

  // Adjust weights to minimize coding cost of last prediction
  void update() {
    for (int i=0; i<ncxt; ++i) {
      // pr[i] is scaled 12 bits, so (y<<12)-pr[i] lies within +-4096 and
      // the factor of 7 keeps err inside the 16-bit range train() expects.
      int err=((y<<12)-pr[i])*7;
      assert(err>=-32768 && err<32768);
      train(&tx[0], &wx[cxt[i]*N], nx, err);
    }
    nx=base=ncxt=0;  // start collecting inputs and contexts for the next bit
  }

  // Input x (call up to N times)
  void add(int x) {
    assert(nx<N);
    tx[nx++]=x;
  }

  // Set a context (call S times, sum of ranges <= M)
  void set(int cx, int range) {
    assert(range>=0);
    assert(ncxt<S);
    assert(cx>=0);
    assert(base+cx<M);
    cxt[ncxt++]=base+cx;
    base+=range;
  }

  // predict next bit
  // (Declared without the "Mixer::" qualifier: a qualified name is not
  // permitted on a member declaration inside its own class.)
  int p() {
    while (nx&7) tx[nx++]=0;  // pad
    if (mp) {  // combine outputs
      mp->update();
      for (int i=0; i<ncxt; ++i) {
        pr[i]=squash(dot_product(&tx[0], &wx[cxt[i]*N], nx)>>5);
        mp->add(stretch(pr[i]));
      }
      mp->set(0, 1);
      return mp->p();
    }
    else {  // S=1 context
      return pr[0]=squash(dot_product(&tx[0], &wx[0], nx)>>8);
    }
  }
};
// Destroy the chained output Mixer, if any. mp is null unless the
// constructor created one (S>1); deleting a null pointer is a no-op.
Mixer::~Mixer() {
delete mp;
}
// Build a mixer for up to n inputs (rounded up to a multiple of 8),
// m weight sets and s context sets, with every weight initialised to w.
// With more than one context set, a second single-context mixer is created
// to combine the s per-context predictions; its weights start at 0x7fff.
Mixer::Mixer(int n, int m, int s, int w):
    N((n+7)&-8), M(m), S(s), tx(N), wx(N*M),
    cxt(S), ncxt(0), base(0), nx(0), pr(S), mp(0) {
  assert(n>0 && N>0 && (N&7)==0 && M>0);

  // Each context set starts with prediction 2048 (midpoint of 12 bits).
  for (int set=0; set<S; ++set) {
    pr[set]=2048;
  }

  // All N*M weights start at w.
  const int nweights=N*M;
  for (int k=0; k<nweights; ++k) {
    wx[k]=w;
  }

  if (S>1) {
    mp=new Mixer(S, 1, 1, 0x7fff);
  }
}
//////////////////////////// APM //////////////////////////////
// APM maps a probability and a context into a new probability
// that bit y will next be 1. After each guess it updates
// its state to improve future guesses. Methods:
//
// APM a(N) creates with N contexts, uses 66*N bytes memory.
// a.p(pr, cx, rate=8) returns the adjusted probability in context cx (0 to
// N-1). rate determines the learning rate (smaller = faster, default 8).
// Probabilities are scaled 16 bits (0-65535).
// Adaptive probability map: refines an input probability using a context,
// interpolating between 33 table entries per context.
class APM {
  int index;     // table slot of the lower interpolation point used last
  const int N;   // number of contexts
  Array<U16> t;  // [N][33]: p, context -> p
public:
  APM(int n);

  // Train the two entries used by the previous call toward the last bit y,
  // then return the refined probability (12 bits) for input probability
  // pr (12 bits) in context cxt. A smaller rate adapts faster.
  int p(int pr=2048, int cxt=0, int rate=8) {
    assert(pr>=0 && pr<4096 && cxt>=0 && cxt<N && rate>0 && rate<32);
    const int s=stretch(pr);

    // Update the pair of entries that produced the previous prediction.
    const int target=(y<<16)+(y<<rate)-y-y;
    t[index] += (target-t[index]) >> rate;
    t[index+1] += (target-t[index+1]) >> rate;

    // Select the new pair and interpolate between them.
    const int frac=s&127;  // interpolation weight (33 points)
    index=((s+2048)>>7)+cxt*33;
    return (t[index]*(128-frac)+t[index+1]*frac) >> 11;
  }
};
// Initialise every context's 33-entry row to the same curve,
// squash((j-16)*128)*16, so the map starts out as p, cxt -> p.
APM::APM(int n): index(0), N(n), t(n*33) {
  if (N>0) {
    // Row 0 is computed directly.
    for (int j=0; j<33; ++j) {
      t[j]=squash((j-16)*128)*16;
    }
    // Every other row is a copy of row 0.
    for (int row=1; row<N; ++row) {
      for (int j=0; j<33; ++j) {
        t[row*33+j]=t[j];
      }
    }
  }
}
//////////////////////////// StateMap //////////////////////////
// A StateMap maps a nonstationary counter state to a probability.
// After each mapping, the mapping is adjusted to improve future
// predictions. Methods:
//
// sm.p(cx) converts state cx (0-255) to a probability (0-4095).
// Counter state -> probability * 256
// Adaptive table from a state number to a probability.
class StateMap {
protected:
  int cxt;       // state passed to the previous p() call
  Array<U16> t;  // 256 states -> probability * 64K
public:
  StateMap();

  // Move the entry used by the previous call 1/256 of the way (rounded)
  // toward the last bit y, then return the 12-bit probability for state cx.
  int p(int cx) {
    assert(cx>=0 && cx<t.size());
    const int target=y<<16;
    t[cxt] += (target-t[cxt]+128) >> 8;
    cxt=cx;
    return t[cx] >> 4;
  }
};
// Seed each of the 256 states with probability (n1+1)/(n0+n1+2) scaled by
// 64K, where n0=nex(state,2) and n1=nex(state,3) (presumably the state's
// 0 and 1 counts -- confirm against nex()). When one of the two is zero,
// the other is weighted 64x, pushing the estimate toward certainty.
StateMap::StateMap(): cxt(0), t(256) {
  for (int state=0; state<256; ++state) {
    int zeros=nex(state,2);
    int ones=nex(state,3);
    if (zeros==0) {
      ones*=64;
    }
    if (ones==0) {
      zeros*=64;
    }
    t[state]=65536*(ones+1)/(zeros+ones+2);
  }
}
//////////////////////////// hash //////////////////////////////
// Hash 2-5 ints.
// Combine two to five 32-bit values into one 32-bit hash. Omitted
// arguments default to 0xffffffff.
inline U32 hash(U32 a, U32 b, U32 c=0xffffffff, U32 d=0xffffffff,
    U32 e=0xffffffff) {
  // Weighted sum of the inputs.
  U32 h=a*200002979u;
  h+=b*30005491u;
  h+=c*50004239u;
  h+=d*70004807u;
  h+=e*110002499u;

  // Fold in a shifted copy of the sum and of each input.
  U32 mixed=h^(h>>9);
  mixed^=a>>2;
  mixed^=b>>3;
  mixed^=c>>4;
  mixed^=d>>5;
  mixed^=e>>6;
  return mixed;
}
///////////////////////////// BH ////////////////////////////////
// A BH maps a 32 bit hash to an array of B bytes (checksum and B-2 values)
//
// BH bh(N); creates N element table with B bytes each.
// N must be a power of 2. The first byte of each element is
// reserved for a checksum to detect collisions. The remaining
// B-1 bytes are values, prioritized by the first value. This
// byte is 0 to mark an unused element.
//
// bh[i] returns a pointer to the i'th element, such that
// bh[i][0] is a checksum of i, bh[i][1] is the priority, and
// bh[i][2..B-1] are other values (0-255).
// The low lg(n) bits are used as an index into the table.
// If a collision is detected, up to M nearby locations in the same
// cache line are tested and the first matching checksum or
// empty element is returned.
// If no match or empty element is found, then the lowest priority
// element is replaced.
// 2 byte checksum with LRU replacement (except last 2 by priority)
// Hash table of fixed-size elements of B bytes each, looked up by a 32-bit
// value through operator[].
template <int B> class BH {
enum {M=8}; // search limit: number of consecutive elements probed per lookup
Array<U8, 64> t; // elements, i*B bytes in total
U32 n; // size-1, used as the index mask
public:
// Create a table of i elements; i must be a positive power of 2.
BH(int i): t(i*B), n(i-1) {
assert(B>=2 && i>0 && (i&(i-1))==0); // size a power of 2?
}
// Look up (or allocate) the element for i; defined below.
U8* operator[](U32 i);
};
// Find the element for i, creating or replacing one if needed, and move it
// to the front of its group of M consecutive elements. Returns a pointer
// one byte past the start of the element, so result[1] is the byte at
// element offset 2 (the byte tested for "empty" and compared as priority).
template <int B>
inline U8* BH<B>::operator[](U32 i) {
// 16-bit checksum: high half of i XOR low half
int chk=(i>>16^i)&0xffff;
// index of the first element of the group to search
i=i*M&n;
U8 *p;
U16 *cp;
int j;
// Probe the M elements of the group in order.
for (j=0; j<M; ++j) {
p=&t[(i+j)*B];
cp=(U16*)p; // the checksum occupies the first 2 bytes of the element
// byte 2 == 0 marks an unused element: claim it by storing our checksum
if (p[2]==0) *cp=chk;
if (*cp==chk) break; // found
}
if (j==0) return p+1; // front
// NOTE(review): tmp is a function-local static shared by all calls, so
// this function is not reentrant.
static U8 tmp[B]; // element to move to front
if (j==M) {
// No match and no empty element: build a zeroed element carrying our
// checksum, and pick a victim among the last two elements of the group.
--j;
memset(tmp, 0, B);
*(U16*)tmp=chk;
// If the last element has the higher priority byte, discard the
// second-to-last one instead.
if (M>2 && t[(i+j)*B+2]>t[(i+j-1)*B+2]) --j;
}
// Found at position j>0: save a copy so it can be moved to the front.
else memcpy(tmp, cp, B);
// Shift elements 0..j-1 of the group up by one, overwriting element j
// (the found element or the victim), then store tmp at the front.
memmove(&t[(i+1)*B], &t[i*B], j*B);
memcpy(&t[i*B], tmp, B);
return &t[i*B+1];
}
/////////////////////////// ContextMap /////////////////////////
//
// A ContextMap maps contexts to bit histories and makes predictions
// to a Mixer. Methods common to all classes:
//
// ContextMap cm(M, C); creates using about M bytes of memory (a power
// of 2) for C contexts.
// cm.set(cx); sets the next context to cx, called up to C times
// cx is an arbitrary 32 bit value that identifies the context.
// It should be called before predicting the first bit of each byte.
// cm.mix(m) updates Mixer m with the next prediction. Returns 1
// if context cx is found, else 0. Then it extends all the contexts with
// global bit y. It should be called for every bit:
//
// if (bpos==0)
// for (int i=0; i<C; ++i) cm.set(cxt[i]);
// cm.mix(m);
//
// The different types are as follows:
//
// - RunContextMap. The bit history is a count of 0-255 consecutive
// zeros or ones. Uses 4 bytes per whole byte context. C=1.
// The context should be a hash.
// - SmallStationaryContextMap. 0 <= cx < M/512.
// The state is a 16-bit probability that is adjusted after each
// prediction. C=1.
// - ContextMap. For large contexts, C >= 1. Context need not be hashed.
// Feed five inputs derived from bit history state s into mixer m, using sm
// to turn s into a probability. Returns 1 if s is nonzero, else 0.
// nex(s,2) and nex(s,3) are presumably the state's 0 and 1 counts
// (confirm against nex()); only whether each is zero matters here.
inline int mix2(Mixer& m, int s, StateMap& sm) {
  const int prob=sm.p(s);
  const int noZeros = nex(s,2)==0;  // 1 when nex(s,2) is zero, else 0
  const int noOnes = nex(s,3)==0;   // 1 when nex(s,3) is zero, else 0

  // Input 1: stretched probability.
  const int st=stretch(prob)>>2;
  m.add(st);

  // Input 2: difference between the 8-bit probability and its complement.
  const int hi=prob>>4;
  const int lo=255-hi;
  m.add(hi-lo);

  // Inputs 3-5: nonzero only when one of the two counts is zero.
  m.add(st*(noZeros-noOnes));
  m.add((noZeros ? hi : 0)-(noOnes ? lo : 0));
  m.add((noOnes ? hi : 0)-(noZeros ? lo : 0));
  return s>0;
}
// A RunContextMap remembers, per context, the last byte seen and how many
// times in a row it was seen (capped at 255), and predicts that the same
// byte will follow. The table holds m/4 elements of 4 bytes each.
class RunContextMap {
  BH<4> t;
  U8* cp;  // current slot: cp[0] = repeat count, cp[1] = byte
public:
  RunContextMap(int m): t(m/4) {cp=t[0]+1;}

  // Record the last byte buf(1) in the current slot, then switch to the
  // slot for context cx.
  void set(U32 cx) {
    const bool restart = cp[0]==0 || cp[1]!=buf(1);
    if (restart) {
      // empty slot, or a different byte than before: start a new run
      cp[0]=1;
      cp[1]=buf(1);
    }
    else if (cp[0]<255) {
      ++cp[0];
    }
    cp=t[cx]+1;
  }

  // Predict next bit: zero when the bits seen so far in this byte (c0)
  // disagree with the remembered byte; otherwise a value whose sign
  // follows the remembered byte's next bit and whose magnitude is
  // ilog(count+1)*8.
  int p() {
    const int expected=(cp[1]+256)>>(8-bpos);
    if (expected!=c0) {
      return 0;
    }
    const int bit=(cp[1]>>(7-bpos))&1;
    return (bit*2-1)*ilog(cp[0]+1)*8;
  }

  // Add the prediction to m; returns 1 if the slot's count is nonzero.
  int mix(Mixer& m) {
    m.add(p());
    return cp[0]!=0;
  }
};
// Directly indexed map from a small context plus the partial byte c0 to a
// 16-bit probability. m is the table size in bytes (m/2 entries, a power
// of 2). Context should be < m/512; high bits are discarded.
class SmallStationaryContextMap {
  Array<U16> t;  // m/2 probabilities, 16 bits each
  int cxt;       // base index of the current context's 256 entries
  U16 *cp;       // entry used for the last prediction
public:
  SmallStationaryContextMap(int m): t(m/2), cxt(0) {
    assert((m/2&m/2-1)==0); // power of 2?
    for (int slot=0; slot<t.size(); ++slot) {
      t[slot]=32768;  // midpoint of the 16-bit range
    }
    cp=&t[0];
  }

  // Select context cx: 256 entries per context, wrapped to the table size.
  void set(U32 cx) {
    cxt=(cx*256)&(t.size()-256);
  }

  // Move the previously used entry toward the last bit y by 1/2^rate
  // (rounded), then add the stretched prediction for the current partial
  // byte c0 to m.
  void mix(Mixer& m, int rate=7) {
    const int rounding=1<<(rate-1);
    *cp += ((y<<16)-*cp+rounding) >> rate;
    cp=&t[cxt+c0];
    m.add(stretch(*cp>>4));
  }
};
// Context map for large contexts. Most modeling uses this type of context
// map. It includes a built in RunContextMap to predict the last byte seen
// in the same context, and also bit-level contexts that map to a bit
// history state.
//
// Bit histories are stored in a hash table. The table is organized into
// 64-byte buckets aligned on cache page boundaries. Each bucket contains
// a hash chain of 7 elements, plus a 2 element queue (packed into 1 byte)
// of the last 2 elements accessed for LRU replacement. Each element has
// a 2 byte checksum for detecting collisions, and an array of 7 bit history
// states indexed by the last 0 to 2 bits of context. The buckets are indexed
// by a context ending after 0, 2, or 5 bits of the current byte. Thus, each
// byte modeled results in 3 main memory accesses per context, with all other
// accesses to cache.
//
// On bits 0, 2 and 5, the context is updated and a new bucket is selected.
// The most recently accessed element is tried first, by comparing the
// 16 bit checksum, then the 7 elements are searched linearly. If no match
// is found, then the element with the lowest priority among the 5 elements
// not in the LRU queue is replaced. After a replacement, the queue is
// emptied (so that consecutive misses favor a LFU replacement policy).
// In all cases, the found/replaced element is put in the front of the queue.
//
// The priority is the state number of the first element (the one with 0
// additional bits of context). The states are sorted by increasing n0+n1
// (number of bits seen), implementing a LFU replacement policy.
//
// When the context ends on a byte boundary (bit 0), only 3 of the 7 bit
// history states are used. The remaining 4 bytes implement a run model
// as follows: <count:7,d:1> <b1> <b2> <b3> where <b1> is the last byte
// seen, possibly repeated, and <b2> and <b3> are the two bytes seen
// before the first <b1>. <count:7,d:1> is a 7 bit count and a 1 bit
// flag. If d=0 then <count> = 1..127 is the number of repeats of <b1>
// and no other bytes have been seen, and <b2><b3> are not used.
// If <d> = 1 then the history is <b3>, <b2>, and <count> - 2 repeats
// of <b1>. In this case, <b3> is valid only if <count> >= 3 and
// <b2> is valid only if <count> >= 2.
//
// As an optimization, the last two hash elements of each byte (representing
// contexts with 2-7 bits) are not updated until a context is seen for
// a second time. This is indicated by <count,d> = <1,0>. After update,
// <count,d> is updated to <2,0> or <2,1>.
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -