ummsearch.c
来自「基于Blas CLapck的.用过的人知道是干啥的」· C语言 代码 · 共 1,470 行 · 第 1/3 页
C
1,470 行
}}MULTHEAD *CreateMultNode(int imult){ MULTHEAD *mh; mh = malloc(sizeof(MULTHEAD)); assert(mh); mh->imult = imult; mh->rn = NULL; mh->next = NULL; return(mh);}MULTHEAD *GetMultNode(int imult)/* * Finds, and if necessary, creates MULTHEAD with value imult, keeping it * in ascending order */{ MULTHEAD *mh, *mh0; if (imhead) { for (mh=imhead; mh; mh = mh->next) if (mh->imult >= imult) break; if (mh && imhead == mh && mh->imult != imult) { mh = CreateMultNode(imult); mh->next = imhead; imhead = mh; } else if (!mh || mh->imult != imult) { for (mh0=imhead; mh0->next != mh; mh0 = mh0->next); mh0->next = CreateMultNode(imult); mh0->next->next = mh; mh = mh0->next; } } else mh = imhead = CreateMultNode(imult); return(mh);}void KillMultNode(MULTHEAD *mh){ KillAllRoutNodes(mh->rn); free(mh);}void KillThisMultNode(MULTHEAD *mhD){ MULTHEAD *mh; if (mhD == NULL) return; if (mhD == imhead) imhead = imhead->next; else { for (mh=imhead; mh && mh->next != mhD; mh = mh->next); assert(mh); if (mh->next) mh->next = mh->next->next; } KillMultNode(mhD);}void KillAllMultNodes(){ MULTHEAD *mh; while(imhead) { mh = imhead->next; KillMultNode(imhead); imhead = mh; }}ROUTNODE *GetRoutNode(int imult, char *rout, int icase, double mflop)/* * Finds, and if necessary creates, the desired RoutNode */{ MULTHEAD *mh; ROUTNODE *rn; mh = GetMultNode(imult); for (rn = mh->rn; rn; rn = rn->next) if (rn->icase == icase) break; if (!rn) { rn = CreateRoutNode(rout, icase, mflop); if (mh->rn) { rn->next = mh->rn; mh->rn = rn; } else mh->rn = rn; } return(rn);}void PrintTable(FILE *fpout){ MULTHEAD *mh; ROUTNODE *rn; for (mh=imhead; mh; mh = mh->next) { fprintf(fpout, "%3d: ", mh->imult); for (rn=mh->rn; rn; rn = rn->next) fprintf(fpout, "%3d:%1d,%5.1f ", rn->icase, rn->fixed, rn->mflop); fprintf(fpout, "\n"); } fprintf(fpout, "\n");}MULTHEAD *BuildTable(char pre, enum CLEAN_WHICH which, int nb)/* * Builds table of possible cleanup codes, depending on which: * 0 : pMB * 1 : pNB * 2 : pKB */{ ROUTNODE *rn; int i, n, ID, NB[3]; int iin, io1, io2, iflag, muladd, lat, mu, nu, ku; char *MCC, *MMFLAGS; char rout[ROUTLEN], auth[AUTHLEN]; switch(which) { case CleanM: iin = 0; io1 = 1; io2 = 2; break; case CleanN: iin = 1; io1 = 0; io2 = 2; break; case CleanK: iin = 2; io1 = 0; io2 = 1; break; } n = NumUserCases(pre); for (i=0; i < n; i++) { rn = NULL; ID = GetUserCase(pre, -i, &iflag, NB, NB+1, NB+2, &muladd, &lat, &mu, &nu, &ku, rout, auth, &MCC, &MMFLAGS); if (ATL_MMNoClean(iflag)) continue; if (NB[io1] < 0 && NB[io1] != -nb) continue; if (NB[io2] < 0 && NB[io2] != -nb) continue; if (NB[io1] && (nb % NB[io1])) continue; if (NB[io2] && (nb % NB[io2])) continue; if (NB[iin] < 0) { if (-NB[iin] < nb) rn = GetRoutNode(-NB[iin], rout, ID, NOTIMED); } else if (NB[iin] == 0) rn = GetRoutNode(1, rout, ID, NOTIMED); else if (NB[iin] < nb) rn = GetRoutNode(NB[iin], rout, ID, NOTIMED); if (rn) rn->fixed = IsCaseFixed(pre, ID, which); } return(imhead);}int MakeMult(int nb, int mul)/* * takes nb, makes it a multiple of mul by reducing */{ return( (nb / mul) * mul );}int GetPNB(char pre, enum CLEAN_WHICH which, int icase, int NB, int imul, int *pNB)/* * Returns number of pNB that are multiple of imul (max of 3) to be timed; * pNB contains the values to try */{ int i=1, j; int iflag, NBs[3], muladd, lat, mu, nu, ku; char fnam[ROUTLEN], *MCC, *MMFLAGS; pNB[0] = pNB[1] = pNB[2] = 0; assert(GetUserCase(pre, icase, &iflag, NBs, NBs+1, NBs+2, &muladd, &lat, &mu, &nu, &ku, fnam, fnam, &MCC, &MMFLAGS)); if (NBs[which] < 0) pNB[0] = -NBs[which]; else { j = pNB[0] = MakeMult(NB-NB/8, imul); if (!j) pNB[0] = imul; j = MakeMult(NB/2, imul); if (j && j != pNB[0]) { pNB[1] = j; i = 2; j = NB/8; if (NB >= 32) j = Mmax(j, 16); j = pNB[2] = ((j+imul-1)/imul)*imul; if (j && j != pNB[1] && j != pNB[0]) i = 3; else pNB[2] = 0; } } return(i);}#define NO_RESULTS -88.7double GetRes(char *fnam){ FILE *fp; double mf, mflop[3]; int i; fp = fopen(fnam, "r"); if (!fp) return(NO_RESULTS); for (i=0; i != 3; i++) { if (fscanf(fp, " %lf", &mflop[i]) != 1) { fclose(fp); remove(fnam); return(NO_RESULTS); } } fclose(fp); mf = GetAvg(3, TOLERANCE, mflop); return(mf);}double ummcase0( char pre, /* type prefix */ int M, int N, int K, /* problem sizes to time */ int mb, int nb, int kb, /* 0: variable NB, else fixed cpp macro of NB */ int lda, int ldb, int ldc, /* leading dims */ int muladd, int lat, /* muladd and latency settings */ int mu, int nu, int ku, /* unrolling factors */ char *fnam, /* file name to compile */ char *MCC, char *MMFLAGS, /* NULL : use defaults, else comp to use */ char *outnam /* output name */){ char ln[512]; char ch; int i; double mf; if (!FileExists(outnam)) { if (pre == 'c' || pre == 'z') i = sprintf(ln, "make cmmucase mmrout=CASES/%s csC=2 ", fnam); else i = sprintf(ln, "make mmucase mmrout=CASES/%s ", fnam); if (MCC) { ch = (pre == 'c' || pre == 's') ? 'S' : 'D'; i += sprintf(ln+i, "%cMC=\"%s\" %cMCFLAGS=\"%s\" ", ch, MCC, ch, MMFLAGS); } i += sprintf(ln+i, "casnam=%s ", outnam); i += sprintf(ln+i, "pre=%c muladd=%d lat=%d M=%d N=%d K=%d mb=%d nb=%d kb=%d mu=%d nu=%d ku=%d lda=%d ldb=%d ldc=%d ", pre, muladd, lat, M, N, K, mb, nb, kb, mu, nu, ku, lda, ldb, ldc); i += sprintf(ln+i, "\n"); fprintf(stdout, "%s", ln); if (system(ln) != 0) return(-1.0); } mf = GetRes(outnam); if (mf == NO_RESULTS) mf = -1.0; return(mf);}int GetIflag(char pre, int icase){ int iflag, mb, nb, kb, muladd, lat, mu, nu, ku; char fnam[ROUTLEN], *MCC, *MMFLAGS; assert(GetUserCase(pre, icase, &iflag, &mb, &nb, &kb, &muladd, &lat, &mu, &nu, &ku, fnam, fnam, &MCC, &MMFLAGS)); return(iflag);}double GetCleanCase(char pre, enum CLEAN_WHICH which, int icase, int imul, int mb, int nb, int kb){ char cwh[3] = {'M', 'N', 'K'}; char outf[ROUTLEN], fnam[ROUTLEN], *MCC, *MMFLAGS; int ld=kb, NB[3], NB1[3], NBs[3], nb0; int iflag, mb1, nb1, kb1, muladd, lat, mu, nu, ku; assert(GetUserCase(pre, icase, &iflag, NB1, NB1+1, NB1+2, &muladd, &lat, &mu, &nu, &ku, fnam, outf, &MCC, &MMFLAGS)); if (ATL_MMNoClean(iflag)) return(-1.0); NBs[0] = mb; NBs[1] = nb; NBs[2] = kb; nb0 = kb; if (which == CleanK) { nb0 = nb; if (ATL_MMVarLda(iflag)) ld = 0; } NB[0] = NB[1] = NB[2] = nb0; if (NB1[which]) NB[which] = NBs[which]; else NB[which] = 0; sprintf(outf, "res/%cup%cB%d_%d_%dx%dx%d", pre, cwh[which], icase, imul, mb, nb, kb); return(ummcase0(pre, mb, nb, kb, NB[0], NB[1], NB[2], ld, ld, 0, muladd, lat, mu, nu, ku, fnam, MCC, MMFLAGS, outf));}double GetCleanCases0(char pre, enum CLEAN_WHICH which, int nb, int imul, int icase, int n, int *pNB){ int i, NB[3]; double mf0, mf=0.0; NB[0] = NB[1] = NB[2] = nb; for (i=0; i < n; i++) { NB[which] = pNB[i]; mf0 = GetCleanCase(pre, which, icase, imul, NB[0], NB[1], NB[2]); if (mf0 <= 0.0) return(-1.0); /* reject if it fails to run */ mf += mf0; } return(mf / n);}double GetCleanCases(char pre, enum CLEAN_WHICH which, int nb, int imul, int icase){ int n, pNB[3]; n = GetPNB(pre, which, icase, nb, imul, pNB); return(GetCleanCases0(pre, which, nb, imul, icase, n, pNB));}void TimeRouts(char pre, enum CLEAN_WHICH which, int nb, int imul, ROUTNODE *rn0){ ROUTNODE *rn; for (rn=rn0; rn; rn = rn->next) if (rn->mflop == NOTIMED) rn->mflop = GetCleanCases(pre, which, nb, imul, rn->icase);}void TimeTable(char pre, enum CLEAN_WHICH which, int nb)/* * Times table of cleanup codes, depending on which */{ MULTHEAD *mh; for (mh=imhead; mh; mh = mh->next) TimeRouts(pre, which, nb, mh->imult, mh->rn);}ROUTNODE *FindBestRout(ROUTNODE *rn0, double adv)/* * Excludes those of fixed size */{ ROUTNODE *rnB=NULL, *rn; double mfB=0.0, ad; for (rn=rn0; rn; rn = rn->next) { if (rn->fixed != 2) { ad = (rn->fixed == 0)*adv*rn->mflop; if (rn->mflop+ad > mfB) { rnB = rn; mfB = rn->mflop+ad; } } } return(rnB);}ROUTNODE *FindBestFixed2(ROUTNODE *rn0){ ROUTNODE *rnB=NULL, *rn; double mfB=0.0; for (rn=rn0; rn; rn = rn->next) { if (rn->fixed == 2 && rn->mflop > mfB) { rnB = rn; mfB = rn->mflop; } } return(rnB);}void ReduceRouts(char pre, enum CLEAN_WHICH which, MULTHEAD *mh, double adv, int nb)/* * reduces routs to best, giving adv advantage to non-fixed routs */{ double mf; ROUTNODE *rn, *rnN, *rnF, *rnNF=NULL; int n, NB[3]; rn = FindBestRout(mh->rn, adv); rnF = FindBestFixed2(mh->rn); if (rn || rnF) /* some case actually compiled and ran */ { if (rnF && rn) { n = GetPNB(pre, which, rnF->icase, nb, mh->imult, NB); mf = GetCleanCases0(pre, which, nb, mh->imult, rn->icase, n, NB); mf += adv*mf; if (mf < rnF->mflop) { rnNF = CreateRoutNode(rnF->rout, rnF->icase, rnF->mflop); rnNF->fixed = 2; } } else if (rnF) { rnNF = CreateRoutNode(rnF->rout, rnF->icase, rnF->mflop); rnNF->fixed = 2; KillAllRoutNodes(mh->rn); mh->rn = rnNF; } if (rn) { rnN = CreateRoutNode(rn->rout, rn->icase, rn->mflop); rnN->fixed = rn->fixed; KillAllRoutNodes(mh->rn); mh->rn = rnN; if (rnNF) mh->rn->next = rnNF; else mh->rn->next = NULL; } } else KillThisMultNode(mh);}void ReduceMults(char pre, enum CLEAN_WHICH which, double adv, int nb)/* * Finds mults that are multiples of each other, and takes best, * giving adv advantage to non-fixed routs, and .5 adv to lower imults */{ MULTHEAD *mh, *mh0, *mh1; double mf0, mf1; int imult; int NB[3]; NB[0] = NB[1] = NB[2] = nb; for (mh0=imhead; mh0; mh0 = mh0->next) { if (mh0->rn->fixed == 2) continue; for (mh1=mh0, mh=mh0->next; mh; mh = mh->next) { if (mh->imult % mh0->imult == 0) /* higher mult is mult of lower */ { imult = mh->imult; if (mh->rn->fixed == 2) { NB[which] = imult; mf1 = mh->rn->mflop; } else { NB[which] = imult <= 4 ? ((nb-imult)/imult)*imult : ((nb-1)/imult)*imult; mf1 = GetCleanCase(pre, which, mh->rn->icase, imult, NB[0], NB[1], NB[2]); } mf0 = GetCleanCase(pre, which, mh0->rn->icase, imult, NB[0], NB[1], NB[2]); if (((mh0->rn->fixed == 0)+0.5)*adv*mf0+mf0 > (mh->rn->fixed == 0)*adv*mf1+mf1) { mh1->next = mh->next; KillMultNode(mh); mh = mh1; } } mh1 = mh; } }}void ReduceTable(char pre, enum CLEAN_WHICH which, int nb)/*
⌨️ 快捷键说明
复制代码Ctrl + C
搜索代码Ctrl + F
全屏模式F11
增大字号Ctrl + =
减小字号Ctrl + -
显示快捷键?