⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 sslsocket.cpp

📁 一个简单实用的开源C++消息中间件SAFMQ - [软件开发] - [开源 消息中间件 SAFMQ ] 2006-11-23 在很多网络应用中
💻 CPP
📖 第 1 页 / 共 2 页
字号:
*/
size_t SSLSocket::receive(char* readBuffer, size_t length)		throw (SSLSocketException)
{
	// Reads until exactly `length` bytes have been stored in readBuffer, or
	// until SSL_read() reports 0 (treated here as "connection closed").
	// Returns the number of bytes actually stored, which is smaller than
	// `length` only when the connection closed and m_throwOnClose is false.
	//
	// NOTE(review): the close path throws SocketException while the exception
	// specification only lists SSLSocketException -- confirm the class
	// hierarchy makes that legal.

	// SSL_read() takes its length as a plain int, so larger requests are
	// issued in slices.  ~0u >> 1 equals INT_MAX on the usual targets.
	const size_t	maxChunk = static_cast<size_t>(~0u >> 1);
	size_t		nread = 0;	// unsigned, so the comparison with length is exact

	while (nread < length) {
		const size_t	remaining = length - nread;
		const int	want = static_cast<int>(remaining < maxChunk ? remaining : maxChunk);
		const int	iread = SSL_read(pdata->ssl, readBuffer + nread, want);
		pdata->mrc = iread;	// keep the raw SSL return code for later inspection

		if (iread > 0) {
			nread += static_cast<size_t>(iread);
		} else if (iread == 0) {
			if (m_throwOnClose)
				throw SocketException(ECONNRESET);
			break;
		} else {
			const int err = SSL_get_error(pdata->ssl, iread);
			if (err != SSL_ERROR_WANT_READ && err != SSL_ERROR_WANT_WRITE)
				throw SSLSocketException(getErrorMessage().c_str());
			// The operation merely needs to be retried; back off first.
			// NOTE(review): POSIX sleep() takes seconds -- presumably a
			// project macro maps this to milliseconds; confirm.
			sleep(10);
		}
	}
	return nread;
}

/**
Receives data from the socket.  Receives up to <code>length</code> bytes.

@param readBuffer [out] A buffer receiving the data 
@param length [in] The maximum number of bytes in the buffer
@return The number of bytes received, or 0 if the connection was closed
@exception SocketException if the connection was closed and the throw on close flag was set.
@exception SSLSocketException if the SSL layer reports an error.
*/
size_t SSLSocket::receiveSome(char* readBuffer, size_t length)	throw (SSLSocketException)
{
	// Returns as soon as one SSL_read() call delivers data, so the result is
	// between 1 and `length` bytes; 0 is returned only when SSL_read()
	// reports 0 (treated as "connection closed") and m_throwOnClose is false.
	//
	// NOTE(review): the close path throws SocketException while the exception
	// specification only lists SSLSocketException -- confirm the class
	// hierarchy makes that legal.

	// SSL_read() takes its length as a plain int; never hand it a value that
	// does not fit.  ~0u >> 1 equals INT_MAX on the usual targets.
	const size_t	maxChunk = static_cast<size_t>(~0u >> 1);
	const int	want = static_cast<int>(length < maxChunk ? length : maxChunk);
	size_t		nread = 0;

	while (nread == 0) {
		const int	iread = SSL_read(pdata->ssl, readBuffer, want);
		pdata->mrc = iread;	// keep the raw SSL return code for later inspection

		if (iread > 0) {
			nread = static_cast<size_t>(iread);
		} else if (iread == 0) {
			if (m_throwOnClose)
				throw SocketException(ECONNRESET);
			break;
		} else {
			const int err = SSL_get_error(pdata->ssl, iread);
			if (err != SSL_ERROR_WANT_READ && err != SSL_ERROR_WANT_WRITE)
				throw SSLSocketException(getErrorMessage().c_str());
			// The operation merely needs to be retried; back off first.
			// NOTE(review): POSIX sleep() takes seconds -- presumably a
			// project macro maps this to milliseconds; confirm.
			sleep(10);
		}
	}
	return nread;
}

/**
Closes the socket handle, and frees the SSL structures
*/
void SSLSocket::close() throw (SocketException)
{
	// Release the per-connection SSL object, if there is one.
	if (pdata->ssl != NULL) {
		SSL_free(pdata->ssl);
		pdata->ssl = NULL;
	}

	// Release the SSL context, if there is one.
	if (pdata->ctx != NULL) {
		SSL_CTX_free(pdata->ctx);
		pdata->ctx = NULL;
	}

	// Finally let the base class close the socket itself.
	Socket::close();
}

/**
Clones this object.
@return a new SSLSocket constructed from this object.
*/
Socket* SSLSocket::clone() const
{
	// Copy-construct a heap instance and hand it back through the base type.
	SSLSocket*	duplicate = new SSLSocket(*this);
	return duplicate;
}
/**
Sets the SSL certificate from a file.

@param fname [in] The name of the certificate file
@param type [in] The type of the file
*/
bool SSLSocket::setCertificateFile(const char* fname, SSLSocket::file_type type)
{
	// Loads the certificate into the SSL context and/or the SSL object,
	// whichever of the two currently exists.  Returns true only when at least
	// one load was attempted and every attempted load succeeded; the outcome
	// no longer depends on a return code left in pdata->mrc by an earlier,
	// unrelated call.

	// Translate the file_type into the matching OpenSSL encoding constant.
	int	encoding;
	if (type == ssl_pem)
		encoding = SSL_FILETYPE_PEM;
	else if (type == ssl_asn1)
		encoding = SSL_FILETYPE_ASN1;
	else
		return false;	// unrecognised encoding: nothing can be loaded

	bool	loaded = false;

	if (pdata->ctx) {
		pdata->mrc = SSL_CTX_use_certificate_file(pdata->ctx, fname, encoding);
		if (pdata->mrc < 1)
			return false;
		loaded = true;
	}
	if (pdata->ssl) {
		pdata->mrc = SSL_use_certificate_file(pdata->ssl, fname, encoding);
		if (pdata->mrc < 1)
			return false;
		loaded = true;
	}
	return loaded;
}

/**
Sets the SSL private key from a file.

@param fname [in] The name of the private key file
@param type [in] The type of the file
*/
bool SSLSocket::setPrivateKeyFile(const char* fname, SSLSocket::file_type type)
{
	// Loads the private key into the SSL context and/or the SSL object,
	// whichever of the two currently exists.  Returns true only when at least
	// one load was attempted and every attempted load succeeded.  The old
	// "pdata->mrc != 0" result reported success for a stale negative code
	// left behind by an earlier, unrelated call when nothing was loaded here.

	// Translate the file_type into the matching OpenSSL encoding constant.
	int	encoding;
	if (type == ssl_pem)
		encoding = SSL_FILETYPE_PEM;
	else if (type == ssl_asn1)
		encoding = SSL_FILETYPE_ASN1;
	else
		return false;	// unrecognised encoding: nothing can be loaded

	bool	loaded = false;

	if (pdata->ctx) {
		pdata->mrc = SSL_CTX_use_PrivateKey_file(pdata->ctx, fname, encoding);
		if (pdata->mrc < 1)
			return false;
		loaded = true;
	}
	if (pdata->ssl) {
		pdata->mrc = SSL_use_PrivateKey_file(pdata->ssl, fname, encoding);
		if (pdata->mrc < 1)
			return false;
		loaded = true;
	}
	return loaded;
}

/**
Retrieves the certificate of the peer.
@return the X509 certificate of the peer, or NULL if none is available.
*/
X509Certificate* SSLSocket::getPeerCertificate()
{
	// Returns a newly allocated X509Certificate wrapping the peer's
	// certificate, or NULL when there is no SSL object or the peer presented
	// no certificate.  If the handshake has not finished yet it is driven
	// forward once (connect or accept side) before asking again.
	//
	// Fix: a certificate that is already available on the first query used to
	// be dropped and NULL returned; it is now wrapped and returned as well.
	if (!pdata->ssl)
		return NULL;

	X509*	cert = SSL_get_peer_certificate(pdata->ssl);
	if (cert == NULL && !SSL_is_init_finished(pdata->ssl)) {
		if (SSL_in_connect_init(pdata->ssl)) {
			pdata->mrc = SSL_connect(pdata->ssl);
		} else if (SSL_in_accept_init(pdata->ssl)) {
			pdata->mrc = SSL_accept(pdata->ssl);
		}
		cert = SSL_get_peer_certificate(pdata->ssl);
	}
	if (cert == NULL)
		return NULL;

	// NOTE(review): presumably X509Data takes over the reference returned by
	// SSL_get_peer_certificate() -- confirm, otherwise it must be released.
	X509Data	dcert(cert);
	return new X509Certificate(&dcert);
}

/**
Gets the result of performing a certificate verification.
@return The verify result.
*/
int SSLSocket::getVerifyResult()
{
	// Ask OpenSSL for the verification outcome recorded on this SSL object.
	// NOTE(review): pdata->ssl is not checked for NULL here, unlike in the
	// other accessors -- confirm callers only use this on an initialised socket.
	const int	verdict = SSL_get_verify_result(pdata->ssl);
	return verdict;
}

/**
Callback that appends one chunk of error text to a std::string.

@param str [in] The text to append
@param len [in] The number of bytes of <code>str</code> to append
@param u [in,out] Pointer to the std::string receiving the text
@return Always 1
*/
static int error_print(const char* str, size_t len, void* u)
{
	// Append exactly len bytes rather than scanning for a terminating NUL.
	std::string*	sink = static_cast<std::string*>(u);
	sink->append(str, len);
	return 1;
}
/**
Gets the SSL error message
@return The SSL error message
*/
std::string SSLSocket::getErrorMessage()
{
	std::string	text;

	// Without an SSL object nothing is collected; return the empty string.
	if (!pdata->ssl)
		return text;

	// error_print appends each piece of the OpenSSL error output to `text`.
	ERR_print_errors_cb(error_print, &text);
	return text;
}
/**
Gets the SSL object cast as void*.
@return the SSL object cast as void*.
*/
void* SSLSocket::getSSL()
{
	// Hand out the stored SSL handle as an untyped pointer.
	void* const	handle = pdata->ssl;
	return handle;
}
/**
Gets the CTX object cast as void*.
@return the CTX object cast as void*.
*/
void* SSLSocket::getCTX()
{
	// Hand out the stored SSL context handle as an untyped pointer.
	void* const	handle = pdata->ctx;
	return handle;
}

void SSLSocket::init()
{
	pdata = new SSLSocketData();
}

void SSLSocket::init_ssl(enum SSLSocket::ssl_ver version)
{
	// Create the SSL context for the requested protocol version.  A version
	// that matches none of the known values leaves pdata untouched.
	bool	recognised = true;

	if (version == sslv2) {
		pdata->ctx = SSL_CTX_new(SSLv2_method());
	} else if (version == sslv3) {
		pdata->ctx = SSL_CTX_new(SSLv3_method());
	} else if (version == sslv23) {
		pdata->ctx = SSL_CTX_new(SSLv23_method());
	} else {
		recognised = false;
	}

	// Only build the SSL object when a context was actually obtained.
	if (recognised && pdata->ctx)
		pdata->ssl = SSL_new(pdata->ctx);
}

// /////////////////////////////////////////////////////////////////////////////

/**
Constructs the server socket binding to the port and optionally using the bind address and listener backlog
@param bindPort [in] The TCP port to bind to
@param bindAddress [in,optional] Optional address to bind to, if not specified all local address
									will be bound.
@param listenerBacklog [in,optional] Optional depth of the listener queue.
@throw SocketException on an error
*/
SSLServerSocket::SSLServerSocket(short bindPort, ssl_ver version, unsigned long bindAddress, int listenerBacklog) throw (SocketException)
{
	// All of the set-up work is delegated to init().
	this->init(bindPort, version, bindAddress, listenerBacklog);
}

/**
Constructs the server socket binding to the port and optionally using the bind address and listener backlog
@param bindPort [in] The TCP port to bind to
@param key_file [in] The name of the Private Key file
@param key_type [in] The type of the private key file
@param cert_file [in] The name of the certificate file
@param cert_type [in] The type of the certificate file
@param bindAddress [in,optional] Optional address to bind to, if not specified all local address
									will be bound.
@param listenerBacklog [in,optional] Optional depth of the listener queue.
@throw SocketException on an error
*/
SSLServerSocket::SSLServerSocket(short bindPort, const char* key_file, SSLSocket::file_type key_type, const char* cert_file, SSLSocket::file_type cert_type, SSLSocket::ssl_ver version, unsigned long bindAddress, int listenerBacklog) throw (SocketException)
{
	// Set up the server first, then load the private key followed by the
	// certificate.
	this->init(bindPort, version, bindAddress, listenerBacklog);

	// NOTE(review): both boolean results are discarded, so a key or
	// certificate that fails to load goes unnoticed here.
	this->setPrivateKeyFile(key_file, key_type);
	this->setCertificateFile(cert_file, cert_type);
}
/**
Releases the server socket
@throw SocketException on an error
*/
SSLServerSocket::~SSLServerSocket() throw (SocketException)
{
	// Tear-down simply closes the server socket.
	this->close();
}

/**
Accepts a connection from the server socket.
@return A socket connected to a client
@throw SocketException on an error
*/
SSLSocket SSLServerSocket::acceptConnection() throw (SocketException)
{
	// Take the next pending connection; the peer address is not requested.
	int	client = ::accept(m_socket, NULL, NULL);
	if (client < 0) {
		throw SocketException(SOCERRNO);
	}

	// Create an SSL object from this server's context for the new connection
	// and hand both to a server-side SSLSocket.
	SSLSocketData	session;
	session.ssl = SSL_new(pdata->ctx);
	return SSLSocket(client, &session, con_server);
}

/**
Closes the server socket.  Can be called out of band to cause acceptConnection()
to return immediately.
*/
void SSLServerSocket::close() throw (SocketException) {
	Socket::close();
	m_socket = -1; 
}

/**
Verification callback installed by enableVerification().  Both arguments are
ignored and 1 is always returned.
NOTE(review): presumably this lets the handshake continue so that the caller
can inspect getVerifyResult() afterwards -- confirm.
*/
static int verify_callback(int ok, X509_STORE_CTX* ctx)
{
	(void)ok;
	(void)ctx;
	const int	proceed = 1;
	return proceed;
}

/**
Call to enable verification routines
*/
void SSLServerSocket::enableVerification()
{
	// Nothing to configure without an SSL context.
	if (!pdata->ctx)
		return;

	// Request a peer certificate, asking a client for it only once.
	const int	mode = SSL_VERIFY_PEER | SSL_VERIFY_CLIENT_ONCE;
	SSL_CTX_set_verify(pdata->ctx, mode, verify_callback);
}
/**
Sets the SSL certificate from a file.

@param fname [in] The name of the certificate file
@param type [in] The type of the file
*/
bool SSLServerSocket::setCertificateFile(const char* fname, file_type type)
{
	// Forward to the SSLSocket implementation and pass its result through.
	const bool	accepted = SSLSocket::setCertificateFile(fname, type);
	return accepted;
}
/**
Sets the SSL private key from a file.

@param fname [in] The name of the private key file
@param type [in] The type of the file
*/
bool SSLServerSocket::setPrivateKeyFile(const char* fname, file_type type)
{
	// Forward to the SSLSocket implementation and pass its result through.
	const bool	accepted = SSLSocket::setPrivateKeyFile(fname, type);
	return accepted;
}

/**
Gets the CTX object cast as void*.
@return the CTX object cast as void*.
*/
void* SSLServerSocket::getCTX()
{
	// Forward to the SSLSocket implementation and pass its result through.
	void* const	context = SSLSocket::getCTX();
	return context;
}

void SSLServerSocket::init(short bindPort, SSLSocket::ssl_ver version, unsigned long bindAddress, int listenerBacklog) throw (SocketException)
{
	// Creates the listening TCP socket, binds it to bindAddress:bindPort,
	// starts listening with the given backlog and finally creates the SSL
	// context for the requested protocol version.
	// Throws SocketException carrying the socket error code when creating,
	// binding or listening fails.

	m_socket = ::socket(AF_INET, SOCK_STREAM, 0);
	if (m_socket < 0) {
		// Report the real cause instead of letting bind() fail later on an
		// invalid handle.
		throw SocketException(SOCERRNO);
	}

	struct sockaddr_in	saddr;
	memset(&saddr, 0, sizeof(saddr));
	saddr.sin_family = AF_INET;
	saddr.sin_port = htons(bindPort);
	// Assign the value instead of reinterpreting the storage of bindAddress:
	// where unsigned long is wider than 32 bits the old pointer cast picked
	// up the wrong bytes on big-endian targets.  As before, bindAddress is
	// used as-is with no byte-order conversion.
	saddr.sin_addr.s_addr = bindAddress;

	// Best effort: allow the address to be reused; a failure is not fatal.
	int	iReuse = 1;
	::setsockopt(m_socket, SOL_SOCKET, SO_REUSEADDR, (const char*) &iReuse, sizeof(int));

	if (::bind(m_socket, (struct sockaddr*)&saddr, sizeof(saddr)) != 0) {
		int err = SOCERRNO;	// capture before close() can overwrite it
		close();
		throw SocketException(err);
	}

	if (::listen(m_socket, listenerBacklog) != 0) {
		int err = SOCERRNO;	// capture before close() can overwrite it
		close();
		throw SocketException(err);
	}

	init_ssl(version);
}

void SSLServerSocket::init_ssl(SSLSocket::ssl_ver version)
{
	// Create the SSL context for the requested protocol version.  A version
	// that matches none of the known values leaves the context untouched.
	if (version == sslv2) {
		pdata->ctx = SSL_CTX_new(SSLv2_method());
	} else if (version == sslv3) {
		pdata->ctx = SSL_CTX_new(SSLv3_method());
	} else if (version == sslv23) {
		pdata->ctx = SSL_CTX_new(SSLv23_method());
	}
}


}

⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -