// uda.h
////////////////////////////////////////////////////////////////////////
// UDA Core 0.301 (2006.12.19) Author:dwing
////////////////////////////////////////////////////////////////////////
#include <string.h>
#include <stdlib.h>
#include <stdio.h>
typedef unsigned char  U8;
typedef unsigned short U16;
typedef unsigned int   U32;
////////////////////////////////////////////////////////////////////////
// Model size parameter. Original note: MEM=0x100000 -> ~96 MB,
// MEM=0x200000 -> ~179 MB; presumably valid from 0x10000 to 0x2000000.
const U32 MEM = 0x200000;

// Global coder/model state shared by the components below.
static U32 type = 0;  // 0=Default 1=EXE
static U32 y    = 0;  // last bit (0 or 1), set by the encoder
static U32 c0   = 1;  // last 0-7 bits of the partial byte, with a leading 1 bit
static U32 c1   = 0;  // c1, c2, c3: "Last 1,2,3 byte" per the original note
static U32 c2   = 0;
static U32 c3   = 0;
static U32 c4   = 0;  // last 4 whole bytes, packed; last byte is bits 0-7
static U32 bpos = 0;  // number of bits in c0 (0 to 7)
static U32 pos  = 0;  // number of input bytes in buf (not wrapped)
////////////////////////////////////////////////////////////////////////
// Array<T>: fixed-size heap array of SIZE elements, zero-filled (calloc).
// When align != 0 the element storage starts at a multiple of align, which
// must then be a power of 2. Indexing is unchecked.
// Running out of memory is fatal: a message is printed and the process
// exits with EXIT_FAILURE.
// NOTE: the destructor frees the buffer and no copy constructor is defined,
// so an Array must not be copied (both copies would free the same block).
template<class T> class Array
{
  T *data;        // first element (aligned if requested)
  U8 *p;          // raw block from calloc, released by free() in ~Array
  const U32 SIZE; // number of elements
public:
  explicit Array(U32 i,U32 align=0);
  ~Array() {free(p);}
  T& operator[](U32 i) {return data[i];}
  const T& operator[](U32 i)const {return data[i];}
  U32 size() const {return SIZE;}
};
//static U32 memsize=0;
template<class T> Array<T>::Array(U32 i,U32 align):SIZE(i)
{
  // memsize+=align+SIZE*sizeof(T);printf("%u ",memsize);
  p=0;
  // Refuse element counts whose byte size would overflow size_t (possible
  // on 32-bit builds) rather than silently allocating a too-small buffer.
  if((size_t)SIZE<=((size_t)-1-align)/sizeof(T))
    p=(U8*)calloc(align+(size_t)SIZE*sizeof(T),1);
  if(!p)
  {
    printf("\nERROR: Not enough memory\n");
    exit(EXIT_FAILURE);   // report failure to the caller/shell
  }
  // Round up to the next multiple of align. The address is taken as size_t
  // (not U32) so it is not truncated on 64-bit targets. If p is already
  // aligned the result is p+align, still inside the block because align
  // spare bytes were requested.
  data=(T*)(align?(p+align)-((size_t)p&(align-1)):p);
}
////////////////////////////////////////////////////////////////////////
// Buf: ring buffer of bytes. The size must be a power of 2 because every
// index is reduced with the mask size()-1.
//   buf[i]  reference to the byte at index i (mod size)
//   buf(i)  value of the byte at index pos-i (mod size), widened to U32
class Buf
{
  Array<U8> bytes;
  U32 mask() const {return bytes.size()-1;}
public:
  Buf(U32 i):bytes(i) {}
  U32 size() const {return bytes.size();}
  U8& operator[](U32 i) {return bytes[i&mask()];}
  U32 operator()(U32 i) const
  {
    const U32 slot=(pos-i)&mask();
    return bytes[slot];
  }
}buf(MEM*8);
////////////////////////////////////////////////////////////////////////
class Ilog // ilog(x) = round(log2(x)*16),0<=x<65536
{
Array<U8> t;
public:
Ilog():t(65536)
{
U32 x=14155776;
for(U32 i=2;i<65536;++i)
t[i]=(x+=774541002/(i*2-1))>>24; // 2^29/ln 2
}
U32 operator()(U16 x) const {return t[x];}
}ilog;
////////////////////////////////////////////////////////////////////////
// squash(d): logistic function p = 1/(1+exp(-d)).
// d is log-odds scaled by 256 (8 fractional bits); the result is a
// probability scaled to 12 bits (0..4095). Inputs above 2047 give 4095 and
// inputs below -2047 give 0. Otherwise the result is interpolated linearly,
// with rounding, between 33 table points spaced 128 apart.
U32 squash(int d)
{
  static const U32 t[33]=
  {   1,   2,   3,   6,  10,  16,  27,  45,
     73, 120, 194, 310, 488, 747,1101,1546,
   2047,2549,2994,3348,3607,3785,3901,3975,
   4022,4050,4068,4079,4085,4089,4092,4093,4094
  };
  if(d>2047)
    return 4095;
  if(d<-2047)
    return 0;
  const U32 frac=d&127;     // position inside the 128-wide segment
  const int seg=(d>>7)+16;  // segment index, 0..31
  const U32 lo=t[seg];
  const U32 hi=t[seg+1];
  return (lo*(128-frac)+hi*frac+64)>>7;
}
////////////////////////////////////////////////////////////////////////
// Inverse of squash. d = ln(p/(1-p)),d scaled by 8 bits,p by 12 bits.
class Stretch
{
Array<int> t;
public:
Stretch():t(4096)
{
U32 j=0;
for(int x=-2047;x<=2047;++x) // invert squash()
for(U32 i=squash(x);j<=i;++j) t[j]=x;
t[4095]=2047;
}
int operator()(U32 p) const {return t[p];}
}stretch;
////////////////////////////////////////////////////////////////////////
// Mixer m(N,M,S=1,w=0) combines models using M neural networks with
// N inputs each,of which up to S may be selected. If S > 1 then
// the outputs of these neural networks are combined using another
// neural network (with parameters S,1,1). If S = 1 then the
// output is direct. The weights are initially w (+-32K).
// It is used as follows:
// m.update() trains the network where the expected output is the
// last bit (in the global variable y).
// m.add(stretch(p)) inputs prediction from one of N models. The
// prediction should be positive to predict a 1 bit,negative for 0,
// nominally +-256 to +-2K. The maximum allowed value is +-32K but
// using such large values may cause overflow if N is large.
// m.set(cxt,range) selects cxt as one of 'range' neural networks to
// use. 0 <= cxt < range. Should be called up to S times such
// that the total of the ranges is <= M.
// m.p() returns the output prediction that the next bit is 1 as a
// 12 bit number (0 to 4095).
// dot_product returns the dot product t*w of n elements. n is rounded
// up to a multiple of 8, so t and w must be readable for that many shorts.
// Each pair of adjacent products is summed and arithmetically shifted
// right by 8 before being accumulated, as in this C reference:
//int dot_product(short *t,short *w,U32 n)
//{
//  int sum=0;
//  n=(n+7)&-8;
//  for(U32 i=0;i<n;i+=2)
//    sum+=(t[i]*w[i]+t[i+1]*w[i+1]) >> 8;
//  return sum;
//}
// MMX implementation (MSVC x86 inline assembly). It is naked __stdcall:
// arguments are read from the stack and "ret 4*3" pops them. The loop
// walks the 8-element groups from the last one down to index 0.
// For n == 0 it returns 0 and never touches MMX state.
__declspec(naked) int __stdcall dot_product(short *t,short *w,U32 n)
{__asm{
        xor     eax,eax            // result 0 when n == 0 (as in the C version)
        mov     ecx,[esp+12]       // n
        dec     ecx                // ecx = (n-1)&-8: start of the last group
        mov     edx,[esp+ 8]       // w
        and     ecx,-8
        js      done_              // n == 0: nothing to sum
        mov     eax,[esp+ 4]       // t
        pxor    mm0,mm0            // sum = 0
loop_:  movq    mm1,[eax+ecx*2  ]  // t[i..i+3]
        pmaddwd mm1,[edx+ecx*2  ]  // two sums of adjacent products
        movq    mm2,[eax+ecx*2+8]  // t[i+4..i+7]
        pmaddwd mm2,[edx+ecx*2+8]
        psrad   mm1,8              // scale each pair sum down by 8 bits
        psrad   mm2,8
        paddd   mm0,mm1
        paddd   mm0,mm2
        sub     ecx,8              // each iteration sums 8 products
        jns     loop_
        movq    mm1,mm0            // add the 2 halves of mm0, return in eax
        psrlq   mm1,32
        paddd   mm0,mm1
        movd    eax,mm0
        emms
done_:  ret     4*3
}}
// Train neural network weights w[n] given inputs t[n] and err.
// w[i] += t[i]*err,i=0..n-1. t,w,err are signed 16 bits (+- 32K).
// err is scaled 16 bits (representing +- 1/2). w[i] is clamped to +- 32K
// and rounded. n is rounded up to a multiple of 8.
// C reference:
//void train(short *t,short *w,U32 n,int err)
//{
//  n=(n+7)&-8;
//  for(U32 i=0;i<n;++i)
//  {
//    int wt=w[i]+(((t[i]*err*2>>16)+1)>>1);
//    if(wt<-32768) wt=-32768;
//    if(wt> 32767) wt= 32767;
//    w[i]=wt;
//  }
//}
// MMX implementation (MSVC x86 inline assembly). It is naked __stdcall:
// arguments are read from the stack and "ret 4*4" pops them.
// err is passed as U32 and only its low 16 bits are used, so a negative
// error arrives as its two's-complement bit pattern.
// MMX registers are written before the n == 0 check, so emms runs on every
// exit path to restore the x87 state.
__declspec(naked) void __stdcall train(short *t,short *w,U32 n,U32 err)
{__asm{
        mov       eax,[esp+16]       // err
        pcmpeqd   mm1,mm1            // mm1 = all ones ...
        movd      mm0,eax
        psrlw     mm1,15             // ... -> 4 words of 1 (rounding term)
        punpcklwd mm0,mm0            // put 2 copies of err in mm0
        mov       ecx,[esp+12]       // n
        punpcklwd mm0,mm0            // put 4 copies of err in mm0
        mov       eax,[esp+ 4]       // t
        dec       ecx                // ecx = (n-1)&-8: start of the last group
        mov       edx,[esp+ 8]       // w
        and       ecx,-8
        js        done_              // n == 0: skip the loop (emms still runs)
loop_:  movq      mm2,[edx+ecx*2  ]  // w[i..i+3]
        movq      mm3,[eax+ecx*2  ]  // t[i..i+3]
        movq      mm4,[edx+ecx*2+8]  // w[i+4..i+7]
        movq      mm5,[eax+ecx*2+8]  // t[i+4..i+7]
        paddsw    mm3,mm3            // 2*t (saturating)
        paddsw    mm5,mm5
        pmulhw    mm3,mm0            // (2*t*err)>>16
        pmulhw    mm5,mm0
        paddsw    mm3,mm1            // +1, then >>1: rounded delta
        paddsw    mm5,mm1
        psraw     mm3,1
        psraw     mm5,1
        paddsw    mm2,mm3            // w += delta, saturated to +-32K
        paddsw    mm4,mm5
        movq      [edx+ecx*2  ],mm2
        movq      [edx+ecx*2+8],mm4
        sub       ecx,8              // each iteration adjusts 8 weights
        jns       loop_
done_:  emms                         // leave MMX state on every path
        ret       4*4
}}
////////////////////////////////////////////////////////////////////////
// Mixer: neural-network mixer; see the "Mixer m(N,M,S,w)" description
// above dot_product().
//   n: inputs (rounded up to a multiple of 8 -> N)
//   k: weight sets of N shorts each
//   s: maximum number of selected sets
//   w: initial weight
// With s > 1, a nested Mixer(s,1,1,0x7fff) combines the s outputs.
// NOTE: ~Mixer deletes mp and no copy constructor is defined, so a Mixer
// must not be copied.
class Mixer
{
  const U32 N,S;        // max inputs,max context sets
  Array<short> tx;      // N inputs from add()
  Array<short> wx;      // N*M weights
  Array<U32> cxt;       // S contexts (weight-set indices)
  U32 ncxt;             // number of contexts (0 to S)
  U32 base;             // offset of next context
  U32 nx;               // Number of inputs in tx,0 to N
  Array<U32> pr;        // last result (scaled 12 bits)
  Mixer *mp;            // points to a Mixer to combine results
public:
  Mixer(U32 n,U32 k,U32 s=1,int w=0):N((n+7)&-8),S(s),tx(N,16),
    wx(N*k,16),cxt(S),ncxt(0),base(0),nx(0),pr(S),mp(0)
  {
    U32 i;
    for(i=0;i<S  ;++i) pr[i]=2048;
    for(i=0;i<N*k;++i) wx[i]=w;
    if(S>1) mp=new Mixer(S,1,1,0x7fff);
  }
  ~Mixer() {if(S>1) delete mp;}
  // Input x (call up to N times per bit; not bounds-checked).
  void add(int x){tx[nx++]=x;}
  // Select weight set cx among the next 'range' sets
  // (call up to S times, sum of ranges <= M).
  void set(U32 cx,U32 range)
  {
    cxt[ncxt++]=base+cx;
    base+=range;
  }
  // Adjust weights to minimize coding cost of the last prediction, using
  // the global bit y as the target. Then clear the inputs and contexts.
  void update()
  {
    for(U32 i=ncxt;i--;)
      train(&tx[0],&wx[cxt[i]*N],nx,((y<<12)-pr[i])*7);
    nx=base=ncxt=0;
  }
  // Predict the next bit: returns P(1) scaled to 12 bits (0..4095).
  U32 p()
  {
    while(nx&7) tx[nx++]=0;   // pad to a multiple of 8 for dot_product
    if(mp) // combine outputs
    {
      mp->update();           // train the combiner on the previous bit first
      for(U32 i=ncxt;i--;)
      {
        pr[i]=squash(dot_product(&tx[0],&wx[cxt[i]*N],nx)>>5);
        mp->add(stretch(pr[i]));
      }
      mp->set(0,1);
      return mp->p();
    }
    // Single network: predict with the weight set chosen by set(), i.e. the
    // one update() trains. Fall back to set 0 if set() was not called.
    const U32 off=ncxt?cxt[0]*N:0;
    return pr[0]=squash(dot_product(&tx[0],&wx[off],nx)>>8);
  }
}m(512,(16+256*3+512*13),4,128);
////////////////////////////////////////////////////////////////////////
// APM maps a probability and a context into a new probability
// that bit y will next be 1. After each guess it updates
// its state to improve future guesses. Methods:
// APM a(N) creates with N contexts,uses 66*N bytes memory.
// a.p(pr,cx,rate=8) returns the adjusted probability in context cx (0 to
// N-1). rate determines the learning rate (smaller=faster,default 8).
// pr and the returned probability are scaled 12 bits (0-4095); the
// internal table entries are scaled 16 bits (0-65535).
// APM: per-context 33-point interpolation table mapping the stretched input
// probability to a refined probability. The table is trained toward the
// global bit y on each call to p().
class APM
{
U32 index; // t[] slot of the lower interpolation point used by the last p() (cxt*33 + bucket)
const U32 N; // number of contexts
Array<U16> t; // [N][33]: p,context -> p (entries scaled 16 bits)
public:
// Row 0 holds squash((j-16)*128)*16 for j=0..32: squash sampled at the
// stretch value of each bucket, scaled from 12 to 16 bits. This makes p()
// initially return approximately its input. Rows 1..N-1 copy row 0.
// NOTE(review): for j<16, (j-16)*128 is computed in unsigned U32. It relies
// on the conversion to squash's int parameter yielding the negative value,
// which holds on two's-complement targets.
APM(U32 n):index(0),N(n),t(n*33) // maps p,cxt -> p initially
{
for(U32 i=0;i<N;++i)
for(U32 j=0;j<33;++j)
t[i*33+j] = i==0 ? squash((j-16)*128)*16 : t[j];
}
// Move the two entries used by the previous call toward y. Then return the
// new 12-bit prediction for the 12-bit input pr in context cxt (0..N-1),
// interpolated between two adjacent table points.
// Must be called once per coded bit; the first call updates t[0],t[1].
U32 p(U32 pr=2048,U32 cxt=0,U32 rate=8)
{
pr=stretch(pr); // log-odds in [-2047,2047]; a negative value wraps in U32, which the masks/offsets below tolerate
const U32 g=(y<<16)+(y<<rate)-y-y; // target: 0 when y==0, 65536+2^rate-2 when y==1
// (g-t)>>rate is a logical shift of a possibly wrapped U32. After truncation back
// to U16 its low 16 bits equal the signed (arithmetic) result as long as rate<=16.
t[index ] += (g-t[index ])>>rate;
t[index+1] += (g-t[index+1])>>rate;
const U32 w=pr&127; // interpolation weight (33 points)
index=((pr+2048)>>7)+cxt*33; // bucket 0..31 within the row for context cxt
return (t[index]*(128-w)+t[index+1]*w) >> 11; // 16-bit entries * 7-bit weights -> 23 bits; >>11 gives 12 bits
}
};
////////////////////////////////////////////////////////////////////////
// nex(state,sel): column sel (0..3) of row state (0..255) in state_table.
// NOTE(review): judging from the table rows, columns 0/1 look like the next
// state after a 0/1 bit and columns 2/3 like counts of 0s/1s -- confirm.
#define nex(state,sel) state_table[state][sel]
static const U8 state_table[256][4]={
{ 1, 2, 0, 0},{ 3, 5, 1, 0},{ 4, 6, 0, 1},{ 7, 10, 2, 0},//0-3
{ 8, 12, 1, 1},{ 9, 13, 1, 1},{ 11, 14, 0, 2},{ 15, 19, 3, 0},//4-7
{ 16, 23, 2, 1},{ 17, 24, 2, 1},{ 18, 25, 2, 1},{ 20, 27, 1, 2},//8-11
{ 21, 28, 1, 2},{ 22, 29, 1, 2},{ 26, 30, 0, 3},{ 31, 33, 4, 0},//12-15
{ 32, 35, 3, 1},{ 32, 35, 3, 1},{ 32, 35, 3, 1},{ 32, 35, 3, 1},//16-19
{ 34, 37, 2, 2},{ 34, 37, 2, 2},{ 34, 37, 2, 2},{ 34, 37, 2, 2},//20-23
{ 34, 37, 2, 2},{ 34, 37, 2, 2},{ 36, 39, 1, 3},{ 36, 39, 1, 3},//24-27
{ 36, 39, 1, 3},{ 36, 39, 1, 3},{ 38, 40, 0, 4},{ 41, 43, 5, 0},//28-31
{ 42, 45, 4, 1},{ 42, 45, 4, 1},{ 44, 47, 3, 2},{ 44, 47, 3, 2},//32-35
{ 46, 49, 2, 3},{ 46, 49, 2, 3},{ 48, 51, 1, 4},{ 48, 51, 1, 4},//36-39
{ 50, 52, 0, 5},{ 53, 43, 6, 0},{ 54, 57, 5, 1},{ 54, 57, 5, 1},//40-43
{ 56, 59, 4, 2},{ 56, 59, 4, 2},{ 58, 61, 3, 3},{ 58, 61, 3, 3},//44-47
{ 60, 63, 2, 4},{ 60, 63, 2, 4},{ 62, 65, 1, 5},{ 62, 65, 1, 5},//48-51
{ 50, 66, 0, 6},{ 67, 55, 7, 0},{ 68, 57, 6, 1},{ 68, 57, 6, 1},//52-55
{ 70, 73, 5, 2},{ 70, 73, 5, 2},{ 72, 75, 4, 3},{ 72, 75, 4, 3},//56-59
{ 74, 77, 3, 4},{ 74, 77, 3, 4},{ 76, 79, 2, 5},{ 76, 79, 2, 5},//60-63
{ 62, 81, 1, 6},{ 62, 81, 1, 6},{ 64, 82, 0, 7},{ 83, 69, 8, 0},//64-67
{ 84, 71, 7, 1},{ 84, 71, 7, 1},{ 86, 73, 6, 2},{ 86, 73, 6, 2},//68-71
{ 44, 59, 5, 3},{ 44, 59, 5, 3},{ 58, 61, 4, 4},{ 58, 61, 4, 4},//72-75
{ 60, 49, 3, 5},{ 60, 49, 3, 5},{ 76, 89, 2, 6},{ 76, 89, 2, 6},//76-79
{ 78, 91, 1, 7},{ 78, 91, 1, 7},{ 80, 92, 0, 8},{ 93, 69, 9, 0},//80-83
{ 94, 87, 8, 1},{ 94, 87, 8, 1},{ 96, 45, 7, 2},{ 96, 45, 7, 2},//84-87
{ 48, 99, 2, 7},{ 48, 99, 2, 7},{ 88,101, 1, 8},{ 88,101, 1, 8},//88-91
{ 80,102, 0, 9},{103, 69,10, 0},{104, 87, 9, 1},{104, 87, 9, 1},//92-95
{106, 57, 8, 2},{106, 57, 8, 2},{ 62,109, 2, 8},{ 62,109, 2, 8},//96-99
{ 88,111, 1, 9},{ 88,111, 1, 9},{ 80,112, 0,10},{113, 85,11, 0},//100-03
{114, 87,10, 1},{114, 87,10, 1},{116, 57, 9, 2},{116, 57, 9, 2},//104-07
{ 62,119, 2, 9},{ 62,119, 2, 9},{ 88,121, 1,10},{ 88,121, 1,10},//108-11
{ 90,122, 0,11},{123, 85,12, 0},{124, 97,11, 1},{124, 97,11, 1},//112-15
{126, 57,10, 2},{126, 57,10, 2},{ 62,129, 2,10},{ 62,129, 2,10},//116-19
{ 98,131, 1,11},{ 98,131, 1,11},{ 90,132, 0,12},{133, 85,13, 0},//120-23
{134, 97,12, 1},{134, 97,12, 1},{136, 57,11, 2},{136, 57,11, 2},//124-27
{ 62,139, 2,11},{ 62,139, 2,11},{ 98,141, 1,12},{ 98,141, 1,12},//128-31
{ 90,142, 0,13},{143, 95,14, 0},{144, 97,13, 1},{144, 97,13, 1},//132-35
{ 68, 57,12, 2},{ 68, 57,12, 2},{ 62, 81, 2,12},{ 62, 81, 2,12},//136-39
{ 98,147, 1,13},{ 98,147, 1,13},{100,148, 0,14},{149, 95,15, 0},//140-43
{150,107,14, 1},{150,107,14, 1},{108,151, 1,14},{108,151, 1,14},//144-47
{100,152, 0,15},{153, 95,16, 0},{154,107,15, 1},{108,155, 1,15},//148-51
{100,156, 0,16},{157, 95,17, 0},{158,107,16, 1},{108,159, 1,16},//152-55
{100,160, 0,17},{161,105,18, 0},{162,107,17, 1},{108,163, 1,17},//156-59
{110,164, 0,18},{165,105,19, 0},{166,117,18, 1},{118,167, 1,18},//160-63
{110,168, 0,19},{169,105,20, 0},{170,117,19, 1},{118,171, 1,19},//164-67
{110,172, 0,20},{173,105,21, 0},{174,117,20, 1},{118,175, 1,20},//168-71
{110,176, 0,21},{177,105,22, 0},{178,117,21, 1},{118,179, 1,21},//172-75
{110,180, 0,22},{181,115,23, 0},{182,117,22, 1},{118,183, 1,22},//176-79
{120,184, 0,23},{185,115,24, 0},{186,127,23, 1},{128,187, 1,23},//180-83
{120,188, 0,24},{189,115,25, 0},{190,127,24, 1},{128,191, 1,24},//184-87
{120,192, 0,25},{193,115,26, 0},{194,127,25, 1},{128,195, 1,25},//188-91
{120,196, 0,26},{197,115,27, 0},{198,127,26, 1},{128,199, 1,26},//192-95
{120,200, 0,27},{201,115,28, 0},{202,127,27, 1},{128,203, 1,27},//196-99
{120,204, 0,28},{205,115,29, 0},{206,127,28, 1},{128,207, 1,28},//200-03
{120,208, 0,29},{209,125,30, 0},{210,127,29, 1},{128,211, 1,29},//204-07
{130,212, 0,30},{213,125,31, 0},{214,137,30, 1},{138,215, 1,30},//208-11
{130,216, 0,31},{217,125,32, 0},{218,137,31, 1},{138,219, 1,31},//212-15
{130,220, 0,32},{221,125,33, 0},{222,137,32, 1},{138,223, 1,32},//216-19
{130,224, 0,33},{225,125,34, 0},{226,137,33, 1},{138,227, 1,33},//220-23
{130,228, 0,34},{229,125,35, 0},{230,137,34, 1},{138,231, 1,34},//224-27
{130,232, 0,35},{233,125,36, 0},{234,137,35, 1},{138,235, 1,35},//228-31
{130,236, 0,36},{237,125,37, 0},{238,137,36, 1},{138,239, 1,36},//232-35
{130,240, 0,37},{241,125,38, 0},{242,137,37, 1},{138,243, 1,37},//236-39
{130,244, 0,38},{245,135,39, 0},{246,137,38, 1},{138,247, 1,38},//240-43
{140,248, 0,39},{249,135,40, 0},{250, 69,39, 1},{ 80,251, 1,39},//244-47
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -