⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 iocp.h

📁 基于完成端口的TCP网络通信框架实现 工程iocp中包含了框架实现的所有代码
💻 H
字号:
#pragma once
#include <winsock2.h>
#include <windows.h>
#include <assert.h>
#include <list>
#include <vector>
#include <map>
#include "iocpapi.h"
using std::list;
using std::vector;
using std::map;
using std::make_pair;

// Iterates iter over every element of container. iter must be declared by
// the caller with the container's iterator type.
#define for_each_iter(container,iter) \
  for((iter)=(container).begin();(iter)!=(container).end();++(iter))

// Like for_each_iter, but the loop body may erase the element at iter
// (for containers whose erase does not invalidate other iterators, e.g.
// std::list or std::map). iterNext is captured before the body runs and
// becomes the next iter. iterNext is only computed while iter!=end(), so
// end() is never incremented (that would be undefined behavior, and used to
// happen on the last step and on empty containers).
#define for_each_iter_safe(container,iter,iterNext) \
  for((iter)=(container).begin(); \
      (iter)!=(container).end()&&((iterNext)=(iter),++(iterNext),true); \
      (iter)=(iterNext))
// Abstract lock interface. Concrete mutexes implement Lock()/Unlock();
// Locker uses it to hold a lock for the duration of a scope.
class SysMutex
{
public:
  // Virtual so that a mutex can be destroyed through a SysMutex pointer.
  virtual ~SysMutex()
  {
  }

  // Acquires the lock.
  virtual void Lock() = 0;

  // Releases the lock.
  virtual void Unlock() = 0;
};

// Scoped lock guard: locks the given mutex on construction and unlocks it
// when the Locker goes out of scope (including during stack unwinding).
// The mutex must outlive the Locker.
class Locker
{
public:
  explicit Locker(SysMutex *mutex):m_mutex(mutex){assert(m_mutex!=NULL);m_mutex->Lock();}
  ~Locker(){m_mutex->Unlock();}
private:
  // Non-copyable (declared, never defined): a copy would Unlock() the same
  // mutex a second time when it is destroyed.
  Locker(const Locker &);
  Locker &operator=(const Locker &);

  SysMutex *m_mutex;
};

// SysMutex backed by a Win32 CRITICAL_SECTION (process-local; per the Win32
// docs the owning thread may re-enter it, with one Unlock per Lock).
class UserMutex: public SysMutex
{
public:
  UserMutex(){InitializeCriticalSection(&m_cs);}
  ~UserMutex(){DeleteCriticalSection(&m_cs);}
  void Lock(){EnterCriticalSection(&m_cs);}
  void Unlock(){LeaveCriticalSection(&m_cs);}
private:
  // Non-copyable (declared, never defined): a CRITICAL_SECTION must not be
  // copied, and a copy would also be deleted twice.
  UserMutex(const UserMutex &);
  UserMutex &operator=(const UserMutex &);

  CRITICAL_SECTION m_cs;
};

// Owning wrapper around an unnamed Win32 event object.
// bManualReset=false (default) creates an auto-reset event; true creates a
// manual-reset event that stays signaled until Reset(). The event starts
// non-signaled.
class SysEvent
{
public:
  // explicit: prevents a stray bool from silently converting into a new event.
  explicit SysEvent(bool bManualReset = false){m_hEvent=CreateEvent(NULL, bManualReset, false, NULL);}
  // CreateEvent returns NULL on failure; there is nothing to close then.
  ~SysEvent(){if(m_hEvent!=NULL) CloseHandle(m_hEvent);}
  // Signals the event.
  void Set(){SetEvent(m_hEvent);}
  // Waits up to ms milliseconds (INFINITE by default). Returns true only if
  // the event was signaled; false on timeout or failure.
  bool Wait(unsigned int ms = INFINITE){return WAIT_OBJECT_0 == WaitForSingleObject(m_hEvent, ms);}
  // Returns the event to the non-signaled state.
  void Reset(){ResetEvent(m_hEvent);}
private:
  // Non-copyable (declared, never defined): a copy would close the same
  // handle twice.
  SysEvent(const SysEvent &);
  SysEvent &operator=(const SysEvent &);

  HANDLE m_hEvent;
};

// NET_BLOCK_SIZE: payload capacity (bytes) of one PerIoData buffer.
// It is meant to accept any positive value. 4096-64 was picked so that a
// whole PerIoData (up to 64 bytes of header fields plus this payload) fits
// in one 4 KB memory page.
// NOTE(review): on 64-bit builds the header fields (OVERLAPPED, WSABUF,
// SOCKET, ...) may exceed 64 bytes; check sizeof(PerIoData) if page fit matters.
const int NET_BLOCK_SIZE = (4096 - 64);

// Kind of operation recorded in PerIoData::op_type. Only declared here; the
// enumerators are defined outside this header.
// NOTE(review): an opaque enum declaration without an underlying type is not
// standard C++ (it compiles as an MSVC extension); a portable form would be
// `enum NET_OPTYPE : int;` with a matching definition — confirm before changing.
enum NET_OPTYPE;
// One overlapped I/O request together with its data buffer.
struct PerIoData
{
  // These fields are used by the implementation
  // of the NetIOCP... classes.
  // NOTE(review): overlapped is the first member, presumably so the
  // OVERLAPPED* delivered by the completion port can be cast back to
  // PerIoData* — keep it first.
  OVERLAPPED overlapped;
  WSABUF databuf;
  NET_OPTYPE op_type;
  SOCKET sock;     // socket the buffer belongs to (MsgIOCP looks up its state by it)

  // These fields can be used by the classes' users
  int len;         // number of valid bytes in data, at most NET_BLOCK_SIZE
  char data[NET_BLOCK_SIZE];
};

// Common I/O-completion-port engine behind NetIOCPServer and NetIOCPClient.
//
// Public API: start()/stop() control the engine; runloop() is polled by the
// application's main loop (see the NetIOCPServer usage notes); senddata()
// hands a PerIoData to the framework for sending on a socket; close() closes
// a socket through the framework.
// Derived classes implement the three net_*_handler callbacks and may
// override the start_run()/stop_run() hooks, which are no-ops here.
//
// Event flow: worker threads report network events by appending NetEvents
// to mnequeue under mnequeue_mutex (post_netevent). recv_allnetevents()
// moves all of them in one splice into mnequeue_main, which peek_netevent()
// drains without locking.
// NOTE(review): because mnequeue_main is unlocked, presumably only the thread
// calling runloop() touches it — confirm against the implementation.
class NetIOCPBase
{
public:
  NetIOCPBase();
  virtual ~NetIOCPBase();
  void start();
  void stop();
  void runloop();
  void senddata(SOCKET s,PerIoData *iodata);
  void close(SOCKET s);
protected:
  virtual void net_open_handler(SOCKET s)=0;
  virtual void net_closed_handler(SOCKET s)=0;
  virtual void net_data_handler(PerIoData *data)=0;
  // Optional hooks for derived classes (NetIOCPServer overrides both).
  virtual void start_run(){}
  virtual void stop_run(){}
  HANDLE miocp;   // presumably the completion port handle
private:
  // Non-copyable (declared, never defined): the object holds raw handles
  // (miocp, mwthread) and raw Channel pointers, so a member-wise copy would
  // leave two objects sharing them.
  NetIOCPBase(const NetIOCPBase &);
  NetIOCPBase &operator=(const NetIOCPBase &);

  static unsigned int __stdcall completion_port_worker_thread(void *cookie);
  static int ref;
  void worker_func();

  void aio_send(SOCKET s,PerIoData *iodata);
  void aio_recv(SOCKET s,PerIoData *iodata);
  void aio_resend(PerIoData *iodata,int lastlen);
  void wkshutdown();
  
  vector<HANDLE> mwthread;
  bool started;

  bool create_sendqueue(SOCKET s);
  bool delete_sendqueue(SOCKET s);
  // Per-socket send state kept in mmsend_pool (guarded by mmsend_pool_mutex).
  struct Channel
  {
    SOCKET sock;
    bool busy;                       // presumably: a send is currently outstanding
    list<PerIoData *> sendqueue;     // buffers waiting to be sent
  };
  map<SOCKET,Channel *> mmsend_pool;
  UserMutex mmsend_pool_mutex;

  // Notifications (network events) sent from the worker threads to the main thread
  enum NET_EVENT
  {
    NE_SOCKNEW,    // a new socket connection
    NE_SOCKCLOSED, // the socket has been closed
    NE_DATARECVED, // data has been received
  };
  struct NetEvent
  {
    NET_EVENT event;
    PerIoData *iodata;
  };
  list<NetEvent> mnequeue;        // filled by workers, guarded by mnequeue_mutex
  UserMutex mnequeue_mutex;
  list<NetEvent> mnequeue_main;   // drained by peek_netevent, not locked
  // Appends one event to the shared queue (thread-safe).
  void post_netevent(const NetEvent &ne)
  {
    Locker lock(&mnequeue_mutex);
    mnequeue.push_back(ne);
  }
  // Moves every pending event from the shared queue into mnequeue_main while
  // holding the lock only for the splice.
  void recv_allnetevents()
  {
    Locker lock(&mnequeue_mutex);
    mnequeue_main.splice(mnequeue_main.end(),mnequeue);
  }
  // Pops the oldest event of mnequeue_main into *ne; returns false if empty.
  bool peek_netevent(NetEvent *ne)
  {
    if(!mnequeue_main.empty())
    {
      *ne=mnequeue_main.front();
      mnequeue_main.pop_front();
      return true;
    }
    return false;
  }
};

/**
  class NetIOCPServer
  The basic server-side network data transfer class.
  To use it:
   1  Derive your own class from NetIOCPServer, e.g.
      class NetIOCPTestServer:public NetIOCPServer
   2  Override the pure virtual functions
      net_open_handler   //called when a new client connects
      net_closed_handler //called when a client has been closed
      net_data_handler   //called when data arrives
   3  Then drive the framework:
      //...
      NetIOCPTestServer *server=new NetIOCPTestServer(listenport);
      server->start();
      //the main loop
      while(1)
      {
        //...
        server->runloop();
      }
      delete server;
   4  To send data to a peer,
      first prepare the data (yourlen must not exceed NET_BLOCK_SIZE), e.g.
      PerIoData *iodata=new PerIoData;
      iodata->len=yourlen;
      memcpy(iodata->data,yourdata,yourlen);
      then send it:
      server->senddata(destsocket,iodata);
   5  To close a client socket, call
      server->close(clientsocket);
      Do not call the system function closesocket()
      directly; otherwise the IOCP framework cannot clean up
      the resources.
*/
class NetIOCPServer:public NetIOCPBase
{
public:
  // port: TCP port to listen on.
  explicit NetIOCPServer(unsigned short port);
  ~NetIOCPServer();

protected:
  // Still abstract: a concrete server implements these three callbacks.
  virtual void net_open_handler(SOCKET s)=0;
  virtual void net_closed_handler(SOCKET s)=0;
  virtual void net_data_handler(PerIoData *data)=0;

private:
  // Overrides of the NetIOCPBase start/stop hooks.
  virtual void start_run();
  virtual void stop_run();

  // Entry point and body of the accept thread.
  static unsigned int __stdcall completion_port_accept_thread(void *cookie);
  void accept_func();

  unsigned short mport;   // listening port passed to the constructor
  HANDLE mathread;        // presumably the accept thread's handle
  SOCKET mserver;         // presumably the listening socket
};

/**
  class NetIOCPClient
  The client counterpart of NetIOCPServer.
  The only difference is that
    you call connect() to connect to a server.
*/
class NetIOCPClient:public NetIOCPBase
{
public:
  // Connects to the server at dotip:port (dotip presumably a dotted-decimal
  // IPv4 string — confirm). Returns true on success; presumably the new
  // socket is stored in *s.
  bool connect(const char *dotip,unsigned short port,SOCKET *s);

protected:
  // Still abstract: a concrete client implements these three callbacks.
  virtual void net_open_handler(SOCKET s)=0;
  virtual void net_closed_handler(SOCKET s)=0;
  virtual void net_data_handler(PerIoData *data)=0;
};

/**
  class MsgIOCPServer and MsgIOCPClient
  These classes find the message boundaries in the TCP byte stream;
  applications usually derive from one of them.

  The network message format is
  datalen data datalen data...

  There is no gap between messages or
  between datalen and data.

  "datalen" is the size of "data" in bytes. It is sent as a raw
  unsigned short (2 bytes, host byte order, no htons conversion).
*/

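// message structure
// NET_MESSAGE_SIZE may be any positive number up to 65535, because the
// length prefix on the wire is an unsigned short.
// NetMessage and NET_MESSAGE_SIZE are not defined in this header; the
// commented-out definitions below are kept for reference only
// (presumably the live ones come from iocpapi.h — confirm).
//const int NET_MESSAGE_SIZE=4096-32;
//struct NetMessage
//{
//  int datalen;
//  char data[NET_MESSAGE_SIZE];
//};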

// Per-socket reassembly state used by MsgIOCP::net_data_handler.
struct MsgData
{
  int recv_lenlen;     // bytes of the 2-byte length prefix received so far
  int recv_datalen;    // bytes of the message body received so far
  unsigned short len;  // decoded body length, valid once the whole prefix has arrived
  NetMessage *msg;     // body buffer: allocated in net_open_handler, deleted in net_closed_handler
};
template<class T>
class MsgIOCP:public T
{
public:
  explicit MsgIOCP(unsigned short port)
    :T(port)
  {
  }
  MsgIOCP(){}
  void sendmessage(SOCKET s,NetMessage *msg);
protected:
  virtual void msg_open_handler(SOCKET s)=0;
  virtual void msg_closed_handler(SOCKET s)=0;
  virtual void msg_message_handler(SOCKET s,NetMessage *msg)=0;
  virtual void msg_illegal_handler(SOCKET s,int msglen)=0;
private:
  void net_open_handler(SOCKET s);
  void net_closed_handler(SOCKET s);
  void net_data_handler(PerIoData *iodata);
  map<SOCKET,MsgData> mrecvpot;
  typedef map<SOCKET,MsgData>::iterator MapPotIter;
};

//
// Sends msg to socket s as a 2-byte length prefix followed by msg->data.
// The stream "prefix + data" is cut into PerIoData blocks of NET_BLOCK_SIZE
// bytes; only the last block may be shorter.
// Ownership of msg passes to this function: it is deleted before returning,
// so the caller must neither access nor delete it afterwards.
//
template<class T>
void MsgIOCP<T>::sendmessage(SOCKET s,NetMessage *msg)
{
  assert(msg->datalen<=NET_MESSAGE_SIZE&&msg->datalen>0);
  // The length prefix travels as an unsigned short; longer messages would be
  // silently truncated on the wire.
  assert(msg->datalen<=0xFFFF);

  const unsigned short len=(unsigned short)msg->datalen;
  const char *head=(const char *)&len;
  const int head_len=(int)sizeof(len);
  const int total_len=head_len+msg->datalen;

  // index walks the logical stream "head bytes followed by msg->data".
  int index=0;
  while(index<total_len)
  {
    PerIoData *iodata=new PerIoData;
    iodata->len=0;
    // Copy whatever part of the length prefix has not been sent yet.
    while(index<head_len&&iodata->len<NET_BLOCK_SIZE)
    {
      iodata->data[iodata->len++]=head[index++];
    }
    // Fill the remaining room of the block with body bytes.
    if(index>=head_len)
    {
      int room=NET_BLOCK_SIZE-iodata->len;
      int left=total_len-index;
      int copy_len=left<room?left:room;
      memcpy(iodata->data+iodata->len,msg->data+(index-head_len),copy_len);
      iodata->len+=copy_len;
      index+=copy_len;
    }
    // this->: senddata lives in the dependent base T, which unqualified
    // lookup does not search in a template.
    this->senddata(s,iodata);
  }
  delete msg;
}

//
// Called when socket s is opened: creates its reassembly state with a fresh
// NetMessage buffer, then forwards the event to msg_open_handler.
//
template<class T>
void MsgIOCP<T>::net_open_handler(SOCKET s)
{
  MapPotIter iter=mrecvpot.find(s);
  if(mrecvpot.end()==iter)
  {
    MsgData data;
    data.recv_lenlen=0;
    data.recv_datalen=0;
    data.len=0;
    // Allocate only when the entry is actually inserted, so a duplicate
    // open cannot leak a buffer.
    data.msg=new NetMessage;
    mrecvpot.insert(make_pair(s,data));
  }
  else
  {
    // s opened twice without a close: should not happen. Keep the existing
    // buffer and restart reassembly from scratch.
    assert(0);
    iter->second.recv_lenlen=0;
    iter->second.recv_datalen=0;
    iter->second.len=0;
  }
  msg_open_handler(s);
}
//
// Called when socket s has been closed: frees its reassembly buffer and
// state (asserting if s was never opened), then forwards the event to
// msg_closed_handler.
//
template<class T>
void MsgIOCP<T>::net_closed_handler(SOCKET s)
{
  MapPotIter found=mrecvpot.find(s);
  if(found!=mrecvpot.end())
  {
    delete found->second.msg;
    mrecvpot.erase(found);
  }
  else
  {
    // closing a socket that was never registered
    assert(0);
  }
  msg_closed_handler(s);
}
//
// Reassembles length-prefixed messages from one received block.
// A block may stop in the middle of the 2-byte length prefix or of a body;
// that partial state lives in mrecvpot[iodata->sock] and is resumed on the
// next call. Every completed message is passed to msg_message_handler.
// A decoded length of 0 or above NET_MESSAGE_SIZE is reported to
// msg_illegal_handler, and the rest of this block is dropped.
//
template<class T>
void MsgIOCP<T>::net_data_handler(PerIoData *iodata)
{
  assert(iodata->len>0&&iodata->len<=NET_BLOCK_SIZE);
  const int head_len=(int)sizeof(unsigned short);

  MapPotIter iter=mrecvpot.find(iodata->sock);
  if(mrecvpot.end()==iter)
  {
    // data for a socket that was never opened: should not happen
    assert(0);
    MsgData data;
    data.recv_datalen=0;
    data.recv_lenlen=0;
    data.len=0;
    data.msg=new NetMessage;
    // Continue with the freshly inserted entry; iter must not stay at end().
    iter=mrecvpot.insert(make_pair(iodata->sock,data)).first;
  }
  MsgData *pot=&iter->second;
  int iodata_index=0;
  while(iodata_index<iodata->len)
  {
    //message head
    if(pot->recv_lenlen<head_len)
    {
      int avail=iodata->len-iodata_index;
      int need=head_len-pot->recv_lenlen;
      int copy_len=avail<need?avail:need;
      memcpy((char *)&pot->len+pot->recv_lenlen,iodata->data+iodata_index,copy_len);
      pot->recv_lenlen+=copy_len;
      iodata_index+=copy_len;
      if(iodata_index==iodata->len)
      {
        // block exhausted; the prefix may still be incomplete
        break;
      }
    }
    //illegal message (the prefix is complete at this point)
    if(pot->len<=0||pot->len>NET_MESSAGE_SIZE)
    {
      msg_illegal_handler(iodata->sock,pot->len);
      pot->recv_lenlen=0;
      pot->recv_datalen=0;
      pot->len=0;
      // NOTE(review): the stream is not resynchronized here; presumably the
      // handler closes the socket — confirm.
      break;
    }
    //message body
    if(pot->recv_datalen<pot->len)
    {
      int avail=iodata->len-iodata_index;
      int need=pot->len-pot->recv_datalen;
      int copy_len=avail<need?avail:need;
      memcpy(pot->msg->data+pot->recv_datalen,iodata->data+iodata_index,copy_len);
      pot->recv_datalen+=copy_len;
      iodata_index+=copy_len;
      if(pot->recv_datalen==pot->len)
      {
        pot->msg->datalen=pot->len;
        pot->recv_lenlen=0;
        pot->recv_datalen=0;
        pot->len=0;
        // NOTE(review): pot must stay valid across this call, i.e. the
        // handler must not synchronously trigger net_closed_handler for
        // this socket — confirm.
        msg_message_handler(iodata->sock,pot->msg);
      }
    }
  }
}

// Ready-made message-framing bases: derive from MsgIOCPServer for a
// listening server or from MsgIOCPClient for a client, and implement the
// msg_*_handler callbacks.
typedef MsgIOCP<NetIOCPServer> MsgIOCPServer;
typedef MsgIOCP<NetIOCPClient> MsgIOCPClient;

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -