📄 tftpd.cpp
字号:
/** 文件名: tftpd.cpp** 版本: 1.0** 描述: TFTP服务器** 完成日期: 2004-11-20*/#include <unistd.h>#include <fcntl.h>#include <sys/socket.h>#include <sys/stat.h>#include <sys/select.h>#include <sys/time.h>#include <arpa/inet.h>#include <signal.h>#include <errno.h>#include <memory.h>#include <stdio.h>#include <stdlib.h>#include <ctype.h>/** 函数声明*/static inline void exiterr(const char* reason = NULL, unsigned char status = 1);static char* strtolower(char* str);static inline char* initmsg(char* buf, unsigned short op, unsigned short index, const char* msg = NULL);static int openfile(const char* fname);static int readn(int fd, char* buf, int len);static int readnn(FILE*file, char* buf, int len);static unsigned char startserv(const sockaddr* pclinaddr, socklen_t clinlen, const char* fname, bool isbin);/** 全局数据声明和初始化*///操作码const unsigned short RRQ = 1, WRQ = 2, DATA = 3, ACK = 4, ERR = 5;//读写缓冲区const int BSIZE = 1024;char buf[BSIZE];//指向缓冲区数据段的指针char* const pdata = buf+4;//超时时间const timeval TMOUT = {5, 0};//超时或错误ACK的最大次数const int MAXROUND = 5;//数据段的大小const int DSIZE = 512;/** main 函数*/int main(int argc, char* argv[]){ const unsigned short PORT = 69; int servfd; sockaddr_in servaddr, clinaddr; socklen_t clinlen = sizeof(clinaddr); int nread; unsigned short op; char* fname, *mod; bool isbin; pid_t pid; //获取并转到TFTP根目录 switch(argc) { case 1: break; case 2: if(chdir(argv[1]) < 0) exiterr("chdir error"); break; default: puts("usage: tftpd [TFTP root]"); return 1; } //完成启动UDP服务器的一般性步骤 memset(&servaddr, 0, sizeof(servaddr)); servaddr.sin_family = AF_INET; servaddr.sin_port = htons(PORT); servaddr.sin_addr.s_addr = htonl(INADDR_ANY); if((servfd = socket(AF_INET, SOCK_DGRAM, 0)) < 0) exiterr("socket error"); if(bind(servfd, (sockaddr*)&servaddr, sizeof(servaddr)) < 0) exiterr("bind error"); //忽略子进程中止信号 signal(SIGCHLD, SIG_IGN); //下面转变自身为deamon进程 //确保自己不是组长, 并转到后台 if((pid = fork()) < 0) exiterr("fork error"); else if(pid > 0) exit(0); //生成新会晤期 setsid(); //忽略会晤期主席中止信号 signal(SIGHUP, SIG_IGN); //会晤期主席退出, 子进程成为孤儿进程 if((pid = fork()) < 0) exit(1); else if(pid > 0) exit(0); //重定向stdin,stdout和stderr到"/dev/null" int fd = open("/dev/null", O_RDWR); if(fd < 0) exit(1); for(int i = 0; i<3; ++i) dup2(fd, i); //循环等待请求数据报 while(true) { //读取数据报 nread = recvfrom(servfd, buf, BSIZE, 0, (sockaddr*)&clinaddr, &clinlen); //若出错则退出 if(nread<0) if(errno == EINTR) continue; else exit(1); //若数据报太短(nread<2+1+1+5+1)或 //不是以0结尾, 则继续 if(nread < 10 || buf[nread-1] != 0) continue; //获取操作码 memcpy(&op, buf, 2); op = ntohs(op); //获取指向文件名字段的指针 fname = buf+2; //获取指向传输模式字段的指针 mod = fname+strlen(fname)+1; //若数据报格式不对, 则继续 if(mod >= buf+nread-1 || strlen(mod)+strlen(fname)+4 != (unsigned)nread) continue; //得到传输模式, 若传输模式不能被识别则继续 if(strcmp(strtolower(mod), "netascii") == 0) isbin = false; else if(strcmp(strtolower(mod), "octet") == 0) isbin = true; else continue; //若操作码不是RRQ, 则发送错误报文, 并继续 if(op != RRQ) { initmsg(buf, ERR, 4, "invalid operation"); sendto(servfd, buf, strlen(pdata)+5, 0, (sockaddr*)&clinaddr, clinlen); continue; } //生成子进程, 处理客户请求 if((pid = fork()) == 0) { close(servfd); exit(startserv((sockaddr*)&clinaddr, clinlen, fname, isbin)); } else if(pid < 0) exit(1); } return 0;}/** 错误退出的包装函数** 参数:* const char* reason: 要显示的消息* char status: 退出码*/static inline void exiterr(const char* reason/* = NULL*/, unsigned char status/* = 1*/){ perror(reason); exit(status);}/** 将指定字符串转换为小写* * 参数:* char* str: 要转换的字符串** 返回值:* char*: 转换后的字符串*/static char* strtolower(char* s){ char* str = s; while(*str != 0) { *str = tolower(*str); ++str; } return s;}/** 初始化消息** 参数:* char* buf: 目标缓冲区* short op: 操作码* unsigned short index: 块编号或差错码* const char* msg: 要发送的差错消息** 返回值:* char*: buf*/static inline char* initmsg(char* buf, unsigned short op, unsigned short index, const char* msg/* = NULL*/){ op = htons(op); index = htons(index); //填充前4个字节 memcpy(buf, &op, 2); memcpy(buf+2, &index, 2); //填充消息字段 if(msg != NULL) strcpy(pdata, msg); return buf;}/** 检查路径的合法性, 合法则尝试打开文件* 如下路径都是非法的:* /foo/bar : 绝对路径* foo/../bar : 含有上级目录** 参数:* const char* fname: 文件名** 返回:* int: 成功则返回非负描述字, 出错返回-1或-2* -1表示文件未找到, -2表示非法请求*/static int openfile(const char* fname){ struct stat st; //若文件名长度为0或为绝对路径或指向上级目录, //则作为非法请求处理 if(strlen(fname) == 0 || fname[0] == '/' || strncmp(fname, "../", 3) == 0) return -2; //查找fname中是否有对上级目录的访问 char* p = strpbrk(fname, "/"); while(p != NULL) { //若fname中含有"/../"或以"/.."结尾, 则作为非法请求处理 if(strncmp(p, "/../", 4) == 0 || strcmp(p, "/..") == 0) return -2; p = strpbrk(p+1, "/"); } //获取文件信息, 失败则返回-1 if(stat(fname, &st) < 0) return -1; //若文件类型不是普通文件, 则作为非法请求 if(!S_ISREG(st.st_mode)) return -2; return open(fname, O_RDONLY);}/** 从fd读取最多len个字节到buf** 参数:* int fd: 描述字* char* buf 读缓冲区* int len: 要读的长度** 返回:* int: 成功返回实际读到的字节数, 失败返回-1*/static int readn(int fd, char* buf, int len){ int nread, nleft = len; char* pb = buf; while(nleft>0) { if((nread = read(fd, pb, nleft)) < 0) if(errno == EINTR) continue; else return -1; if(nread == 0) break; nleft -= nread; pb += nread; } return len-nleft;}/** 从file读取数据, 并把"\n"转化为"\r\n"后, 放到buf里** 参数:* FILE* file: 要读的文件流* char* buf 读缓冲区* int len: 缓冲区大小** 返回:* int: 返回实际读到的字节数*/static int readnn(FILE* file, char* buf, int len){ int slen, nleft = len; char* pb = buf; char c; static bool isLF = false;//标记上次在最后读到的是不是LF //如果上次在最后读到LF, //则写一个LF到缓冲区 if(isLF) { isLF = false; buf[0] = '\n'; --nleft; ++pb; } //读数据到缓冲区 while(nleft > 0) { //EOF则跳出 if(fgets(pb, nleft, file) == NULL) break; slen = strlen(pb); pb += slen; nleft -= slen; //读到"\n"则转换为"\r\n" if(slen > 0 && *(pb-1) == '\n') { //肯定有空间放下2个字节 memcpy(pb-1, "\r\n", 2); --nleft; ++pb; } //最后一个字节 if(nleft == 1) { //如果最后的字符是LF, 则写一个CR到缓冲区, //并打开isLF标记 if((c = fgetc(file)) == '\n') { *pb = '\r'; nleft = 0; isLF = true; } //如果最后的字符不是LF, 则直接写到到缓冲区, else if(c != EOF) { *pb = c; nleft = 0; } break; } } return len - nleft;}/** 启动子服务器** 参数:* const sockaddr* pclinaddr: 客户IP地址结构的指针* socklen_t clinlen: 客户IP地址的长度* const char* fname: 客户请求的文件* bool isbin: 标记二进制传输模式** 返回:* unsigned char: 成功则返回0, 出错返回1或2或3*/static unsigned char startserv(const sockaddr* pclinaddr, socklen_t clinlen, const char* fname, bool isbin){ int clinfd, fd; //UDP套接字, 文件描述字 FILE* file = NULL; //用于netascii模式的文件流 int nread, scount; unsigned short ack[2]; //存放客户应答 timeval tmout; //超时时间 fd_set fdset; FD_ZERO(&fdset); //连接客户程序 if((clinfd = socket(AF_INET, SOCK_DGRAM, 0)) < 0) exit(1); if(connect(clinfd, pclinaddr, clinlen) < 0) exit(1); //打开文件, 出错则发送错误报文, 并返回 fd = openfile(fname); if(fd == -1) { initmsg(buf, ERR, 1, "file not found"); write(clinfd, buf, strlen(pdata)+5); return 1; } else if(fd == -2) { initmsg(buf, ERR, 2, "invalid access"); write(clinfd, buf, strlen(pdata)+5); return 2; } //为netascii模式生成文件流, 失败则 //发送错误报文, 并返回 if((!isbin) && (file = fdopen(fd, "r")) == NULL) { initmsg(buf, ERR, 0, "can't create FILE"); write(clinfd, buf, strlen(pdata)+5); return 1; } //读取文件数据, 并发送给客户 for(unsigned short index = 1; ; ++index) { //根据不同的模式读取数据到缓冲区 //octet模式 if(isbin) { if((nread = readn(fd, pdata, DSIZE)) < 0) { initmsg(buf, ERR, 0, "read error"); write(clinfd, buf, strlen(pdata)+5); return 3; } } //netascii模式 else nread = readnn(file, pdata, DSIZE); //发送一段数据给客户,直到成功跳出或 //超过最大重发次数退出或错误退出 for(int i = 0; i <= MAXROUND; ++i) { //超过最大重发次数则退出 if(i == MAXROUND) return 3; //发送数据, 错误则退出 if(write(clinfd, initmsg(buf, DATA, index), nread+4) < nread+4) return 3; //准备进行select FD_SET(clinfd, &fdset); tmout = TMOUT; scount = select(clinfd+1, &fdset, NULL, NULL, &tmout); //select错误 if(scount < 0) { if(errno == EINTR) { --i; continue; } else return 3; } //select超时 else if(scount == 0) continue; //此时可以确定clinfd已就绪(不考虑套接口错误) else { int n = read(clinfd, ack, sizeof(ack)); //接收到正确应答, 则跳出 if(n == sizeof(ack) && ntohs(ack[0]) == ACK && ntohs(ack[1]) == index) break; } } //到文件尾则跳出 if(nread < DSIZE) break; //编号溢出则重置 if(index == (unsigned short)0xffff) index = 0; } return 0;}
⌨️ 快捷键说明
复制代码
Ctrl + C
搜索代码
Ctrl + F
全屏模式
F11
切换主题
Ctrl + Shift + D
显示快捷键
?
增大字号
Ctrl + =
减小字号
Ctrl + -