// paq8o8.cpp (excerpt)
ns[state*4]=ns0;
ns[state*4+1]=ns1;
ns[state*4+2]=x;
ns[state*4+3]=y;
}
else if (t[x][y][1]) {
next_state(x0, y0, 0);
next_state(x1, y1, 1);
ns[state*4]=ns0=t[x0][y0][0];
ns[state*4+1]=ns1=t[x1][y1][0]+(t[x1][y1][1]>1);
ns[state*4+2]=x;
ns[state*4+3]=y;
}
// uncomment to print table above
// printf("{%3d,%3d,%2d,%2d},", ns[state*4], ns[state*4+1],
// ns[state*4+2], ns[state*4+3]);
// if (state%4==3) printf(" // %d-%d\n ", state-3, state);
assert(state>=0 && state<256);
assert(t[x][y][1]>0);
assert(t[x][y][0]<=state);
assert(t[x][y][0]+t[x][y][1]>state);
assert(t[x][y][1]<=6);
assert(t[x0][y0][1]>0);
assert(t[x1][y1][1]>0);
assert(ns0-t[x0][y0][0]<t[x0][y0][1]);
assert(ns0-t[x0][y0][0]>=0);
assert(ns1-t[x1][y1][0]<t[x1][y1][1]);
assert(ns1-t[x1][y1][0]>=0);
++state;
}
}
}
// printf("%d states\n", state); exit(0); // uncomment to print table above
}
#endif
///////////////////////////// Squash //////////////////////////////
// return p = 1/(1 + exp(-d)), d scaled by 8 bits, p scaled by 12 bits
int squash(int d) {
static const int t[33]={
1,2,3,6,10,16,27,45,73,120,194,310,488,747,1101,
1546,2047,2549,2994,3348,3607,3785,3901,3975,4022,
4050,4068,4079,4085,4089,4092,4093,4094};
if (d>2047) return 4095;
if (d<-2047) return 0;
int w=d&127;
d=(d>>7)+16;
return (t[d]*(128-w)+t[(d+1)]*w+64) >> 7;
}
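
// Illustrative sketch, not part of the original paq8o8 source: squash()
// interpolates between 33 precomputed points of 4096/(1+exp(-d/256)), so
// squash(0) is 2047 (p ~ 0.5) and the extremes clip to 0 and 4095.  This
// unused helper only prints a sanity check and assumes <stdio.h> is
// included earlier in the file (printf already appears in comments above).
static void squash_demo() {
  printf("squash(-2048)=%d squash(0)=%d squash(2048)=%d\n",
         squash(-2048), squash(0), squash(2048));  // expect 0, 2047, 4095
}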
//////////////////////////// Stretch ///////////////////////////////
// Inverse of squash. d = ln(p/(1-p)), d scaled by 8 bits, p by 12 bits.
// d has range -2047 to 2047 representing -8 to 8. p has range 0 to 4095.
class Stretch {
Array<short> t;
public:
Stretch();
int operator()(int p) const {
assert(p>=0 && p<4096);
return t[p];
}
} stretch;
Stretch::Stretch(): t(4096) {
int pi=0;
for (int x=-2047; x<=2047; ++x) { // invert squash()
int i=squash(x);
for (int j=pi; j<=i; ++j)
t[j]=x;
pi=i+1;
}
t[4095]=2047;
}
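
// Illustrative sketch, not part of the original source: Stretch inverts
// squash() by table lookup, so stretch(squash(d)) lands close to d (exact
// at the 33 interpolation knots, off by a small rounding error in between).
static void stretch_demo() {
  for (int d=-2047; d<=2047; d+=512)
    printf("d=%5d squash(d)=%4d stretch(squash(d))=%5d\n",
           d, squash(d), stretch(squash(d)));
}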
//////////////////////////// Mixer /////////////////////////////
// Mixer m(N, M, S=1, w=0) combines models using M neural networks with
// N inputs each, of which up to S may be selected. If S > 1 then
// the outputs of these neural networks are combined using another
// neural network (with parameters S, 1, 1). If S = 1 then the
// output is direct. The weights are initially w (+-32K).
// It is used as follows:
// m.update() trains the network where the expected output is the
// last bit (in the global variable y).
// m.add(stretch(p)) inputs prediction from one of N models. The
// prediction should be positive to predict a 1 bit, negative for 0,
// nominally +-256 to +-2K. The maximum allowed value is +-32K but
// using such large values may cause overflow if N is large.
// m.set(cxt, range) selects cxt as one of 'range' neural networks to
// use. 0 <= cxt < range. Should be called up to S times such
// that the total of the ranges is <= M.
// m.p() returns the output prediction that the next bit is 1 as a
// 12 bit number (0 to 4095).
// dot_product returns dot product t*w of n elements. n is rounded
// up to a multiple of 8. Result is scaled down by 8 bits.
#ifdef NOASM // no assembly language
int dot_product(short *t, short *w, int n) {
int sum=0;
n=(n+7)&-8;
for (int i=0; i<n; i+=2)
sum+=(t[i]*w[i]+t[i+1]*w[i+1]) >> 8;
return sum;
}
#else // The NASM version uses MMX and is about 8 times faster.
extern "C" int dot_product(short *t, short *w, int n); // in NASM
#endif
// Train neural network weights w[n] given inputs t[n] and err.
// w[i] += t[i]*err, i=0..n-1. t, w, err are signed 16 bits (+- 32K).
// err is scaled 16 bits (representing +- 1/2). w[i] is clamped to +- 32K
// and rounded. n is rounded up to a multiple of 8.
#ifdef NOASM
void train(short *t, short *w, int n, int err) {
n=(n+7)&-8;
for (int i=0; i<n; ++i) {
int wt=w[i]+((t[i]*err*2>>16)+1>>1);
if (wt<-32768) wt=-32768;
if (wt>32767) wt=32767;
w[i]=wt;
}
}
#else
extern "C" void train(short *t, short *w, int n, int err); // in NASM
#endif
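
// Minimal sketch, not in the original source, of how dot_product() and
// train() cooperate as one step of online mixing: predict with a scaled dot
// product, then move each weight by t[i]*err.  Both buffers are sized to a
// multiple of 8 because the routines process 8 elements at a time.  The
// names 'inputs' and 'weights' and the error scaling are illustrative; the
// real scaling is in Mixer::update() and Mixer::p() below.
static void mix_step_demo(int bit) {              // bit = actual next bit (0 or 1)
  static short inputs[8]={256,-256,0,0,0,0,0,0};  // stretched model predictions
  static short weights[8]={0};
  int pr=squash(dot_product(inputs, weights, 8)>>8);  // 12-bit P(bit=1)
  int err=((bit<<12)-pr)*7;                       // same scaling as Mixer::update()
  train(inputs, weights, 8, err);                 // w[i] += t[i]*err, clamped to +-32K
}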
class Mixer {
const int N, M, S; // max inputs, max contexts, max context sets
Array<short, 16> tx; // N inputs from add()
Array<short, 16> wx; // N*M weights
Array<int> cxt; // S contexts
int ncxt; // number of contexts (0 to S)
int base; // offset of next context
int nx; // Number of inputs in tx, 0 to N
Array<int> pr; // last result (scaled 12 bits)
Mixer* mp; // points to a Mixer to combine results
public:
Mixer(int n, int m, int s=1, int w=0);
// Adjust weights to minimize coding cost of last prediction
void update() {
for (int i=0; i<ncxt; ++i) {
int err=((y<<12)-pr[i])*7;
assert(err>=-32768 && err<32768);
if (err) train(&tx[0], &wx[cxt[i]*N], nx, err);
}
nx=base=ncxt=0;
}
// Input x (call up to N times)
void add(int x) {
assert(nx<N);
tx[nx++]=x;
}
// Set a context (call S times, sum of ranges <= M)
void set(int cx, int range) {
assert(range>=0);
assert(ncxt<S);
assert(cx>=0);
assert(base+cx<M);
cxt[ncxt++]=base+cx;
base+=range;
}
// predict next bit
int p() {
while (nx&7) tx[nx++]=0; // pad
if (mp) { // combine outputs
mp->update();
for (int i=0; i<ncxt; ++i) {
pr[i]=squash(dot_product(&tx[0], &wx[cxt[i]*N], nx)>>5);
mp->add(stretch(pr[i]));
}
mp->set(0, 1);
return mp->p();
}
else { // S=1 context
return pr[0]=squash(dot_product(&tx[0], &wx[0], nx)>>8);
}
}
~Mixer();
};
Mixer::~Mixer() {
delete mp;
}
Mixer::Mixer(int n, int m, int s, int w):
N((n+7)&-8), M(m), S(s), tx(N), wx(N*M),
cxt(S), ncxt(0), base(0), nx(0), pr(S), mp(0) {
assert(n>0 && N>0 && (N&7)==0 && M>0);
int i;
for ( i=0; i<S; ++i)
pr[i]=2048;
for ( i=0; i<N*M; ++i)
wx[i]=w;
if (S>1) mp=new Mixer(S, 1, 1, 0x7fff);
}
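
// Illustrative sketch, not part of the original source, of the calling
// pattern documented above the class: once per bit, train on the previous
// bit, add stretched predictions, select a context, and read the mix.
// p1, p2 (12-bit probabilities from two models) and c0 (a byte-valued
// context) are placeholders; a matching mixer would be Mixer m(2, 256).
static int mixer_round_demo(Mixer& m, int p1, int p2, int c0) {
  m.update();           // adjust weights toward the last bit (global y)
  m.add(stretch(p1));   // inputs are stretched, roughly -2047..2047
  m.add(stretch(p2));
  m.set(c0, 256);       // pick one of 256 weight vectors, c0 in 0..255
  return m.p();         // 12-bit probability that the next bit is 1
}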
//////////////////////////// APM1 //////////////////////////////
// APM1 maps a probability and a context into a new probability
// that bit y will next be 1. After each guess it updates
// its state to improve future guesses. Methods:
//
// APM1 a(N) creates with N contexts, uses 66*N bytes memory.
// a.p(pr, cx, rate=7) returns the adjusted probability in context cx (0 to
// N-1). rate determines the learning rate (smaller = faster, default 7).
// Probabilities are scaled 12 bits (0-4095).
class APM1 {
int index; // last p, context
const int N; // number of contexts
Array<U16> t; // [N][33]: p, context -> p
public:
APM1(int n);
int p(int pr=2048, int cxt=0, int rate=7) {
assert(pr>=0 && pr<4096 && cxt>=0 && cxt<N && rate>0 && rate<32);
pr=stretch(pr);
int g=(y<<16)+(y<<rate)-y-y;
t[index] += g-t[index] >> rate;
t[index+1] += g-t[index+1] >> rate;
const int w=pr&127; // interpolation weight (33 points)
index=(pr+2048>>7)+cxt*33;
return t[index]*(128-w)+t[index+1]*w >> 11;
}
};
// maps p, cxt -> p initially
APM1::APM1(int n): index(0), N(n), t(n*33) {
for (int i=0; i<N; ++i)
for (int j=0; j<33; ++j)
t[i*33+j] = i==0 ? squash((j-16)*128)*16 : t[j];
}
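
// Illustrative sketch, not in the original source: APM1 is an SSE-style
// refinement stage, mapping (probability, small context) to a better
// probability.  It learns from the global bit y on the next call, so it is
// called exactly once per bit.  The 8-bit context c and the instance name
// are placeholders.
static int apm1_demo(int pr, int c) {   // pr in 0..4095, c in 0..255
  static APM1 a(256);                   // 256 contexts, 66*256 bytes
  return a.p(pr, c);                    // refined probability, rate 7
}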
//////////////////////////// StateMap, APM //////////////////////////
// A StateMap maps a context to a probability. Methods:
//
// StateMap sm(n) creates a StateMap with n contexts using 4*n bytes memory.
// sm.p(cx, limit) converts state cx (0..n-1) to a probability (0..4095)
// that the next bit y=1, updating the previous prediction with the global bit y (0..1).
// limit (1..1023, default 1023) is the maximum count for computing a
// prediction. Larger values are better for stationary sources.
static int dt[1024]; // i -> 16K/(i+3)
class StateMap {
protected:
const int N; // Number of contexts
int cxt; // Context of last prediction
Array<U32> t; // cxt -> prediction in high 22 bits, count in low 10 bits
inline void update(int limit) {
assert(cxt>=0 && cxt<N);
U32 *p=&t[cxt], p0=p[0];
int n=p0&1023, pr=p0>>10; // count, prediction
if (n<limit) ++p0;
else p0=(p0&0xfffffc00)|limit;
p0+=(((y<<22)-pr)>>3)*dt[n]&0xfffffc00;
p[0]=p0;
}
public:
StateMap(int n=256);
// update bit y (0..1), predict next bit in context cx
int p(int cx, int limit=1023) {
assert(cx>=0 && cx<N);
assert(limit>0 && limit<1024);
update(limit);
return t[cxt=cx]>>20;
}
};
StateMap::StateMap(int n): N(n), cxt(0), t(n) {
for (int i=0; i<N; ++i)
t[i]=1<<31;
}
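
// Illustrative sketch, not in the original source: a StateMap learns
// P(next bit=1) separately for each context.  Each call first updates the
// entry chosen on the previous call using the global bit y, then returns
// the prediction for the new context.  This assumes dt[] has been filled
// elsewhere in the program as the comment above its declaration describes.
static int statemap_demo(int cx) {      // cx in 0..255, e.g. a bit-history state
  static StateMap sm(256);
  return sm.p(cx);                      // 12-bit probability, count capped at 1023
}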
// An APM maps a probability and a context to a new probability. Methods:
//
// APM a(n) creates with n contexts using 96*n bytes memory.
// a.p(pr, cx, limit) updates and returns a new probability (0..4095)
// as with StateMap. pr (0..4095) is considered part of the context.
// The output is computed by interpolating pr into 24 ranges nonlinearly
// with smaller ranges near the ends. The initial output is pr.
// y=(0..1) is the last bit. cx=(0..n-1) is the other context.
// limit=(0..1023) defaults to 255.
class APM: public StateMap {
public:
APM(int n);
int p(int pr, int cx, int limit=255) {
// assert(y>>1==0);
assert(pr>=0 && pr<4096);
assert(cx>=0 && cx<N/24);
assert(limit>0 && limit<1024);
update(limit);
pr=(stretch(pr)+2048)*23;
int wt=pr&0xfff; // interpolation weight of next element
cx=cx*24+(pr>>12);
assert(cx>=0 && cx<N-1);
cxt=cx+(wt>>11);
pr=(t[cx]>>13)*(0x1000-wt)+(t[cx+1]>>13)*wt>>19;
return pr;
}
};
APM::APM(int n): StateMap(n*24) {
for (int i=0; i<N; ++i) {
int p=((i%24*2+1)*4096)/48-2048;
t[i]=(U32(squash(p))<<20)+6;
}
}
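
// Illustrative sketch, not in the original source: APM behaves like APM1
// but derives from StateMap and folds the input probability into the
// context by interpolating over 24 nonlinearly spaced ranges.  The context
// c and the table size are placeholders.
static int apm_demo(int pr, int c) {    // pr in 0..4095, c in 0..1023
  static APM a(1024);                   // 1024 contexts, 96*1024 bytes
  return a.p(pr, c);                    // refined probability, limit 255
}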
//////////////////////////// hash //////////////////////////////
// Hash 2-5 ints.
inline U32 hash(U32 a, U32 b, U32 c=0xffffffff, U32 d=0xffffffff,
U32 e=0xffffffff) {
U32 h=a*200002979u+b*30005491u+c*50004239u+d*70004807u+e*110002499u;
return h^h>>9^a>>2^b>>3^c>>4^d>>5^e>>6;
}
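
// Illustrative sketch, not in the original source: hash() is typically used
// to fold a few context bytes into an index for a power-of-2 table (e.g.
// the BH class below).  The argument names and the 64K table size are
// placeholders.
static U32 hash_demo(U32 order, U32 context_bytes) {
  return hash(order, context_bytes) & 0xffff;  // low bits index a 65536-entry table
}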
///////////////////////////// BH ////////////////////////////////
// A BH maps a 32 bit hash to an array of B bytes (checksum and B-2 values)
//
// BH bh(N); creates N element table with B bytes each.
// N must be a power of 2. The first byte of each element is
// reserved for a checksum to detect collisions. The remaining
// B-1 bytes are values, prioritized by the first value. This
// byte is 0 to mark an unused element.
//
// bh[i] returns a pointer to the i'th element, such that
// bh[i][0] is a checksum of i, bh[i][1] is the priority, and
// bh[i][2..B-1] are other values (0-255).
// The low lg(N) bits of the hash are used as an index into the table.
// If a collision is detected, up to M nearby locations in the same
// cache line are tested and the first matching checksum or
// empty element is returned.
// If no match or empty element is found, then the lowest priority
// element is replaced.
// 2 byte checksum with LRU replacement (except last 2 by priority)
template <int B> class BH {
enum {M=8}; // search limit
Array<U8, 64> t; // elements
U32 n; // size-1
public:
BH(int i): t(i*B), n(i-1) {