📄 rpc.c
字号:
#include <u.h>#include <libc.h>#include <venti.h>#include "session.h"struct { int version; char *s;} vtVersions[] = { VtVersion02, "02", 0, 0,};static char EBigString[] = "string too long";static char EBigPacket[] = "packet too long";static char ENullString[] = "missing string";static char EBadVersion[] = "bad format in version string";static Packet *vtRPC(VtSession *z, int op, Packet *p);VtSession *vtAlloc(void){ VtSession *z; z = vtMemAllocZ(sizeof(VtSession)); z->lk = vtLockAlloc();// z->inHash = vtSha1Alloc(); z->inLock = vtLockAlloc(); z->part = packetAlloc();// z->outHash = vtSha1Alloc(); z->outLock = vtLockAlloc(); z->fd = -1; z->uid = vtStrDup("anonymous"); z->sid = vtStrDup("anonymous"); return z;}voidvtReset(VtSession *z){ vtLock(z->lk); z->cstate = VtStateAlloc; if(z->fd >= 0){ vtFdClose(z->fd); z->fd = -1; } vtUnlock(z->lk);}intvtConnected(VtSession *z){ return z->cstate == VtStateConnected;}voidvtDisconnect(VtSession *z, int error){ Packet *p; uchar *b;vtDebug(z, "vtDisconnect\n"); vtLock(z->lk); if(z->cstate == VtStateConnected && !error && z->vtbl == nil) { /* clean shutdown */ p = packetAlloc(); b = packetHeader(p, 2); b[0] = VtQGoodbye; b[1] = 0; vtSendPacket(z, p); } if(z->fd >= 0) vtFdClose(z->fd); z->fd = -1; z->cstate = VtStateClosed; vtUnlock(z->lk);}voidvtClose(VtSession *z){ vtDisconnect(z, 0);}voidvtFree(VtSession *z){ if(z == nil) return; vtLockFree(z->lk); vtSha1Free(z->inHash); vtLockFree(z->inLock); packetFree(z->part); vtSha1Free(z->outHash); vtLockFree(z->outLock); vtMemFree(z->uid); vtMemFree(z->sid); vtMemFree(z->vtbl); memset(z, 0, sizeof(VtSession)); z->fd = -1; vtMemFree(z);}char *vtGetUid(VtSession *s){ return s->uid;}char *vtGetSid(VtSession *z){ return z->sid;}intvtSetDebug(VtSession *z, int debug){ int old; vtLock(z->lk); old = z->debug; z->debug = debug; vtUnlock(z->lk); return old;}intvtSetFd(VtSession *z, int fd){ vtLock(z->lk); if(z->cstate != VtStateAlloc) { vtSetError("bad state"); vtUnlock(z->lk); return 0; } if(z->fd >= 0) vtFdClose(z->fd); z->fd = fd; vtUnlock(z->lk); return 1;}intvtGetFd(VtSession *z){ return z->fd;}intvtSetCryptoStrength(VtSession *z, int c){ if(z->cstate != VtStateAlloc) { vtSetError("bad state"); return 0; } if(c != VtCryptoStrengthNone) { vtSetError("not supported yet"); return 0; } return 1;}intvtGetCryptoStrength(VtSession *s){ return s->cryptoStrength;}intvtSetCompression(VtSession *z, int fd){ vtLock(z->lk); if(z->cstate != VtStateAlloc) { vtSetError("bad state"); vtUnlock(z->lk); return 0; } z->fd = fd; vtUnlock(z->lk); return 1;}intvtGetCompression(VtSession *s){ return s->compression;}intvtGetCrypto(VtSession *s){ return s->crypto;}intvtGetCodec(VtSession *s){ return s->codec;}char *vtGetVersion(VtSession *z){ int v, i; v = z->version; if(v == 0) return "unknown"; for(i=0; vtVersions[i].version; i++) if(vtVersions[i].version == v) return vtVersions[i].s; assert(0); return 0;}/* hold z->inLock */static intvtVersionRead(VtSession *z, char *prefix, int *ret){ char c; char buf[VtMaxStringSize]; char *q, *p, *pp; int i; q = prefix; p = buf; for(;;) { if(p >= buf + sizeof(buf)) { vtSetError(EBadVersion); return 0; } if(!vtFdReadFully(z->fd, (uchar*)&c, 1)) return 0; if(z->inHash) vtSha1Update(z->inHash, (uchar*)&c, 1); if(c == '\n') { *p = 0; break; } if(c < ' ' || c > 0x7f || *q && c != *q) { vtSetError(EBadVersion); return 0; } *p++ = c; if(*q) q++; } vtDebug(z, "version string in: %s\n", buf); p = buf + strlen(prefix); for(;;) { for(pp=p; *pp && *pp != ':' && *pp != '-'; pp++) ; for(i=0; vtVersions[i].version; i++) { if(strlen(vtVersions[i].s) != pp-p) continue; if(memcmp(vtVersions[i].s, p, pp-p) == 0) { *ret = vtVersions[i].version; return 1; } } p = pp; if(*p != ':') return 0; p++; } }Packet*vtRecvPacket(VtSession *z){ uchar buf[10], *b; int n; Packet *p; int size, len; if(z->cstate != VtStateConnected) { vtSetError("session not connected"); return 0; } vtLock(z->inLock); p = z->part; /* get enough for head size */ size = packetSize(p); while(size < 2) { b = packetTrailer(p, MaxFragSize); assert(b != nil); n = vtFdRead(z->fd, b, MaxFragSize); if(n <= 0) goto Err; size += n; packetTrim(p, 0, size); } if(!packetConsume(p, buf, 2)) goto Err; len = (buf[0] << 8) | buf[1]; size -= 2; while(size < len) { n = len - size; if(n > MaxFragSize) n = MaxFragSize; b = packetTrailer(p, n); if(!vtFdReadFully(z->fd, b, n)) goto Err; size += n; } p = packetSplit(p, len); vtUnlock(z->inLock); return p;Err: vtUnlock(z->inLock); return nil; }intvtSendPacket(VtSession *z, Packet *p){ IOchunk ioc; int n; uchar buf[2]; /* add framing */ n = packetSize(p); if(n >= (1<<16)) { vtSetError(EBigPacket); packetFree(p); return 0; } buf[0] = n>>8; buf[1] = n; packetPrefix(p, buf, 2); for(;;) { n = packetFragments(p, &ioc, 1, 0); if(n == 0) break; if(!vtFdWrite(z->fd, ioc.addr, ioc.len)) { packetFree(p); return 0; } packetConsume(p, nil, n); } packetFree(p); return 1;}intvtGetString(Packet *p, char **ret){ uchar buf[2]; int n; char *s; if(!packetConsume(p, buf, 2)) return 0; n = (buf[0]<<8) + buf[1]; if(n > VtMaxStringSize) { vtSetError(EBigString); return 0; } s = vtMemAlloc(n+1); if(!packetConsume(p, (uchar*)s, n)) { vtMemFree(s); return 0; } s[n] = 0; *ret = s; return 1;}intvtAddString(Packet *p, char *s){ uchar buf[2]; int n; if(s == nil) { vtSetError(ENullString); return 0; } n = strlen(s); if(n > VtMaxStringSize) { vtSetError(EBigString); return 0; } buf[0] = n>>8; buf[1] = n; packetAppend(p, buf, 2); packetAppend(p, (uchar*)s, n); return 1;}intvtConnect(VtSession *z, char *password){ char buf[VtMaxStringSize], *p, *ep, *prefix; int i; USED(password); vtLock(z->lk); if(z->cstate != VtStateAlloc) { vtSetError("bad session state"); vtUnlock(z->lk); return 0; } if(z->fd < 0){ vtSetError("%s", z->fderror); vtUnlock(z->lk); return 0; } /* be a little anal */ vtLock(z->inLock); vtLock(z->outLock); prefix = "venti-"; p = buf; ep = buf + sizeof(buf); p = seprint(p, ep, "%s", prefix); p += strlen(p); for(i=0; vtVersions[i].version; i++) { if(i != 0) *p++ = ':'; p = seprint(p, ep, "%s", vtVersions[i].s); } p = seprint(p, ep, "-libventi\n"); assert(p-buf < sizeof(buf)); if(z->outHash) vtSha1Update(z->outHash, (uchar*)buf, p-buf); if(!vtFdWrite(z->fd, (uchar*)buf, p-buf)) goto Err; vtDebug(z, "version string out: %s", buf); if(!vtVersionRead(z, prefix, &z->version)) goto Err; vtDebug(z, "version = %d: %s\n", z->version, vtGetVersion(z)); vtUnlock(z->inLock); vtUnlock(z->outLock); z->cstate = VtStateConnected; vtUnlock(z->lk); if(z->vtbl) return 1; if(!vtHello(z)) goto Err; return 1; Err: if(z->fd >= 0) vtFdClose(z->fd); z->fd = -1; vtUnlock(z->inLock); vtUnlock(z->outLock); z->cstate = VtStateClosed; vtUnlock(z->lk); return 0; }
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -