⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 tftpd.cpp

📁 TFTP服务器
💻 CPP
字号:
/** 文件名: tftpd.cpp** 版本: 1.0** 描述: TFTP服务器** 完成日期: 2004-11-20*/#include <unistd.h>#include <fcntl.h>#include <sys/socket.h>#include <sys/stat.h>#include <sys/select.h>#include <sys/time.h>#include <arpa/inet.h>#include <signal.h>#include <errno.h>#include <memory.h>#include <stdio.h>#include <stdlib.h>#include <ctype.h>/** 函数声明*/static inline void exiterr(const char* reason = NULL, unsigned char status = 1);static char* strtolower(char* str);static inline char* initmsg(char* buf, unsigned short op, unsigned short index, const char* msg = NULL);static int openfile(const char* fname);static int readn(int fd, char* buf, int len);static int readnn(FILE*file, char* buf, int len);static unsigned char startserv(const sockaddr* pclinaddr, socklen_t clinlen, const char* fname, bool isbin);/** 全局数据声明和初始化*///操作码const unsigned short RRQ = 1, WRQ = 2, DATA = 3, ACK = 4, ERR = 5;//读写缓冲区const int BSIZE = 1024;char buf[BSIZE];//指向缓冲区数据段的指针char* const pdata = buf+4;//超时时间const timeval TMOUT = {5, 0};//超时或错误ACK的最大次数const int MAXROUND = 5;//数据段的大小const int DSIZE = 512;/** main 函数*/int main(int argc, char* argv[]){  const unsigned short PORT = 69;  int servfd;  sockaddr_in servaddr, clinaddr;  socklen_t clinlen = sizeof(clinaddr);  int nread;  unsigned short op;  char* fname, *mod;  bool isbin;  pid_t pid;  //获取并转到TFTP根目录  switch(argc)  {  case 1:    break;  case 2:    if(chdir(argv[1]) < 0)      exiterr("chdir error");    break;  default:    puts("usage: tftpd [TFTP root]");    return 1;  }  //完成启动UDP服务器的一般性步骤  memset(&servaddr, 0, sizeof(servaddr));  servaddr.sin_family = AF_INET;  servaddr.sin_port = htons(PORT);  servaddr.sin_addr.s_addr = htonl(INADDR_ANY);  if((servfd = socket(AF_INET, SOCK_DGRAM, 0)) < 0)    exiterr("socket error");  if(bind(servfd, (sockaddr*)&servaddr, sizeof(servaddr)) < 0)    exiterr("bind error");  //忽略子进程中止信号  signal(SIGCHLD, SIG_IGN);  //下面转变自身为deamon进程  //确保自己不是组长, 并转到后台  if((pid = fork()) < 0)    exiterr("fork error");  else if(pid > 0)    exit(0);  //生成新会晤期  setsid();  //忽略会晤期主席中止信号  signal(SIGHUP, SIG_IGN);  //会晤期主席退出, 子进程成为孤儿进程  if((pid = fork()) < 0)    exit(1);  else if(pid > 0)    exit(0);  //重定向stdin,stdout和stderr到"/dev/null"  int fd = open("/dev/null", O_RDWR);  if(fd < 0)    exit(1);  for(int i = 0; i<3; ++i)    dup2(fd, i);  //循环等待请求数据报  while(true)  {    //读取数据报    nread = recvfrom(servfd, buf, BSIZE, 0, (sockaddr*)&clinaddr, &clinlen);    //若出错则退出    if(nread<0)      if(errno == EINTR)        continue;      else        exit(1);    //若数据报太短(nread<2+1+1+5+1)或    //不是以0结尾, 则继续    if(nread < 10 || buf[nread-1] != 0)      continue;    //获取操作码    memcpy(&op, buf, 2);    op = ntohs(op);    //获取指向文件名字段的指针    fname = buf+2;    //获取指向传输模式字段的指针    mod = fname+strlen(fname)+1;    //若数据报格式不对, 则继续    if(mod >= buf+nread-1 || strlen(mod)+strlen(fname)+4 != (unsigned)nread)      continue;    //得到传输模式, 若传输模式不能被识别则继续    if(strcmp(strtolower(mod), "netascii") == 0)      isbin = false;    else if(strcmp(strtolower(mod), "octet") == 0)      isbin = true;    else      continue;    //若操作码不是RRQ, 则发送错误报文, 并继续    if(op != RRQ)    {      initmsg(buf, ERR, 4, "invalid operation");      sendto(servfd, buf, strlen(pdata)+5, 0, (sockaddr*)&clinaddr, clinlen);      continue;    }    //生成子进程, 处理客户请求    if((pid = fork()) == 0)    {      close(servfd);      exit(startserv((sockaddr*)&clinaddr, clinlen, fname, isbin));    }    else if(pid < 0)      exit(1);  }  return 0;}/** 错误退出的包装函数** 参数:*     const char* reason: 要显示的消息*     char status: 退出码*/static inline void exiterr(const char* reason/* = NULL*/, unsigned char status/* = 1*/){  perror(reason);  exit(status);}/** 将指定字符串转换为小写* * 参数:*     char* str: 要转换的字符串** 返回值:*     char*: 转换后的字符串*/static char* strtolower(char* s){  char* str = s;  while(*str != 0)  {    *str = tolower(*str);    ++str;  }  return s;}/** 初始化消息** 参数:*     char* buf: 目标缓冲区*     short op: 操作码*     unsigned short index: 块编号或差错码*     const char* msg: 要发送的差错消息** 返回值:*     char*: buf*/static inline char* initmsg(char* buf, unsigned short op, unsigned short index, const char* msg/* = NULL*/){  op = htons(op);  index = htons(index);  //填充前4个字节  memcpy(buf, &op, 2);  memcpy(buf+2, &index, 2);  //填充消息字段  if(msg != NULL)    strcpy(pdata, msg);  return buf;}/** 检查路径的合法性, 合法则尝试打开文件* 如下路径都是非法的:*                  /foo/bar : 绝对路径*                  foo/../bar : 含有上级目录** 参数:*     const char* fname: 文件名** 返回:*    int: 成功则返回非负描述字, 出错返回-1或-2*        -1表示文件未找到, -2表示非法请求*/static int openfile(const char* fname){  struct stat st;  //若文件名长度为0或为绝对路径或指向上级目录,  //则作为非法请求处理  if(strlen(fname) == 0 || fname[0] == '/' || strncmp(fname, "../", 3) == 0)    return -2;  //查找fname中是否有对上级目录的访问  char* p = strpbrk(fname, "/");  while(p != NULL)  {    //若fname中含有"/../"或以"/.."结尾, 则作为非法请求处理    if(strncmp(p, "/../", 4) == 0 || strcmp(p, "/..") == 0)      return -2;    p = strpbrk(p+1, "/");  }  //获取文件信息, 失败则返回-1  if(stat(fname, &st) < 0)    return -1;  //若文件类型不是普通文件, 则作为非法请求  if(!S_ISREG(st.st_mode))    return -2;  return open(fname, O_RDONLY);}/** 从fd读取最多len个字节到buf** 参数:*     int fd: 描述字*     char* buf 读缓冲区*     int len: 要读的长度** 返回:*     int: 成功返回实际读到的字节数, 失败返回-1*/static int readn(int fd, char* buf, int len){  int nread, nleft = len;  char* pb = buf;  while(nleft>0)  {    if((nread = read(fd, pb, nleft)) < 0)      if(errno == EINTR)        continue;      else        return -1;    if(nread == 0)      break;    nleft -= nread;    pb += nread;  }  return len-nleft;}/** 从file读取数据, 并把"\n"转化为"\r\n"后, 放到buf里** 参数:*     FILE* file:  要读的文件流*     char* buf 读缓冲区*     int len: 缓冲区大小** 返回:*     int: 返回实际读到的字节数*/static int readnn(FILE* file, char* buf, int len){  int slen, nleft = len;  char* pb = buf;  char c;  static bool isLF = false;//标记上次在最后读到的是不是LF  //如果上次在最后读到LF,  //则写一个LF到缓冲区  if(isLF)  {    isLF = false;    buf[0] = '\n';    --nleft;    ++pb;  }  //读数据到缓冲区  while(nleft > 0)  {    //EOF则跳出    if(fgets(pb, nleft, file) == NULL)      break;    slen = strlen(pb);    pb += slen;    nleft -= slen;    //读到"\n"则转换为"\r\n"    if(slen > 0 && *(pb-1) == '\n')    {      //肯定有空间放下2个字节      memcpy(pb-1, "\r\n", 2);      --nleft;      ++pb;    }    //最后一个字节    if(nleft == 1)    {      //如果最后的字符是LF, 则写一个CR到缓冲区,      //并打开isLF标记      if((c = fgetc(file)) == '\n')      {          *pb = '\r';          nleft = 0;          isLF = true;      }      //如果最后的字符不是LF, 则直接写到到缓冲区,      else if(c != EOF)      {        *pb = c;        nleft = 0;      }      break;    }  }  return len - nleft;}/** 启动子服务器** 参数:*     const sockaddr* pclinaddr: 客户IP地址结构的指针*     socklen_t clinlen: 客户IP地址的长度*     const char* fname: 客户请求的文件*     bool isbin: 标记二进制传输模式** 返回:*    unsigned char: 成功则返回0, 出错返回1或2或3*/static unsigned char startserv(const sockaddr* pclinaddr, socklen_t clinlen, const char* fname, bool isbin){  int clinfd, fd; //UDP套接字, 文件描述字  FILE* file = NULL; //用于netascii模式的文件流  int nread, scount;  unsigned short ack[2]; //存放客户应答  timeval tmout; //超时时间  fd_set fdset;  FD_ZERO(&fdset);  //连接客户程序  if((clinfd = socket(AF_INET, SOCK_DGRAM, 0)) < 0)    exit(1);  if(connect(clinfd, pclinaddr, clinlen) < 0)    exit(1);  //打开文件, 出错则发送错误报文, 并返回  fd = openfile(fname);  if(fd == -1)  {    initmsg(buf, ERR, 1, "file not found");    write(clinfd, buf, strlen(pdata)+5);    return 1;  }  else if(fd == -2)  {    initmsg(buf, ERR, 2, "invalid access");    write(clinfd, buf, strlen(pdata)+5);    return 2;  }  //为netascii模式生成文件流, 失败则  //发送错误报文, 并返回  if((!isbin) && (file = fdopen(fd, "r")) == NULL)  {    initmsg(buf, ERR, 0, "can't create FILE");    write(clinfd, buf, strlen(pdata)+5);    return 1;  }  //读取文件数据, 并发送给客户  for(unsigned short index = 1; ; ++index)  {    //根据不同的模式读取数据到缓冲区    //octet模式    if(isbin)    {      if((nread = readn(fd, pdata, DSIZE)) < 0)      {        initmsg(buf, ERR, 0, "read error");        write(clinfd, buf, strlen(pdata)+5);        return 3;      }    }    //netascii模式    else      nread = readnn(file, pdata, DSIZE);    //发送一段数据给客户,直到成功跳出或    //超过最大重发次数退出或错误退出    for(int i = 0; i <= MAXROUND; ++i)    {      //超过最大重发次数则退出      if(i == MAXROUND)        return 3;      //发送数据, 错误则退出      if(write(clinfd, initmsg(buf, DATA, index), nread+4) < nread+4)        return 3;      //准备进行select      FD_SET(clinfd, &fdset);      tmout = TMOUT;      scount = select(clinfd+1, &fdset, NULL, NULL, &tmout);      //select错误      if(scount < 0)      {        if(errno == EINTR)        {          --i;          continue;        }        else          return 3;      }      //select超时      else if(scount == 0)        continue;      //此时可以确定clinfd已就绪(不考虑套接口错误)      else      {        int n = read(clinfd, ack, sizeof(ack));        //接收到正确应答, 则跳出        if(n == sizeof(ack) && ntohs(ack[0]) == ACK && ntohs(ack[1]) == index)          break;      }    }    //到文件尾则跳出    if(nread < DSIZE)      break;    //编号溢出则重置    if(index == (unsigned short)0xffff)      index = 0;  }  return 0;}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -