⭐ 欢迎来到虫虫下载站! | 📦 资源下载 📁 资源专辑 ℹ️ 关于我们
⭐ 虫虫下载站

📄 server.cpp

📁 实现内存数据库的源代码
💻 CPP
📖 第 1 页 / 共 3 页
字号:
		break;
	    }
	} else { 
	    TRACE_MSG(("Field '%s' not found\n", columnName));
	    response = cli_column_not_found;
	    break;
	}
    }
    return data;
}


/**
 * Insert one record on behalf of a client session and send the reply.
 *
 * When `prepare` is true, `data` starts with the zero-terminated statement
 * text ("insert into <table>"), followed by a one-byte column count and the
 * column descriptions consumed by checkColumns(); the packed column values
 * follow. When `prepare` is false, `data` contains only the packed column
 * values and the statement must already have been prepared (stmt->table set).
 *
 * The reply written to the session socket is: 4-byte response code, 4-byte
 * counter (0 when no table descriptor is known, otherwise the table's
 * autoincrement counter or row count depending on AUTOINCREMENT_SUPPORT),
 * followed by the packed oid of the new record (0 on failure).
 *
 * @param session  client session owning the statement list and the socket
 * @param stmt_id  client-side statement identifier
 * @param data     request payload (see above)
 * @param prepare  true to (re)prepare the statement before inserting
 * @return result of writing the reply to the session socket
 */
bool dbServer::insert(dbSession* session, int stmt_id, char* data, bool prepare)
{
    // All locals are declared up front (without late initializers) because
    // the error paths jump to return_response with goto.
    dbStatement* stmt = findStatement(session, stmt_id);
    dbTableDescriptor* desc = NULL;
    dbColumnBinding* cb;
    int4   response;
    char   reply_buf[sizeof(cli_oid_t) + 8];
    char*  dst;
    oid_t  oid = 0;
    size_t offs;
    int    n_columns;

    if (stmt == NULL) {
        if (!prepare) {
            // Executing a statement which was never prepared
            response = cli_bad_statement;
            goto return_response;
        }
        stmt = new dbStatement(stmt_id);
        stmt->next = session->stmts;
        session->stmts = stmt;
    } else {
        if (prepare) {
            stmt->reset();
        } else if ((desc = stmt->table) == NULL) {
            // Statement exists but was not prepared as an insert
            response = cli_bad_descriptor;
            goto return_response;
        }
    }
    if (prepare) {
        session->scanner.reset(data);
        if (session->scanner.get() != tkn_insert
            || session->scanner.get() != tkn_into
            || session->scanner.get() != tkn_ident)
        {
            response = cli_bad_statement;
            goto return_response;
        }
        desc = db->findTable(session->scanner.ident);
        if (desc == NULL) {
            response = cli_table_not_found;
            goto return_response;
        }
        // Skip the statement text and its terminating zero
        data += strlen(data)+1;
        // The column count is a single byte; read it as unsigned so that
        // counts above 127 do not become negative where plain char is signed.
        n_columns = (unsigned char)*data++;
        data = checkColumns(stmt, n_columns, desc, data, response);
        if (response != cli_ok) {
            goto return_response;
        }
        stmt->table = desc;
    }

    // First pass: remember where each column value starts in the message
    // and compute the total record size (fixed part + varying components).
    offs = desc->fixedSize;
    for (cb = stmt->columns; cb != NULL; cb = cb->next) {
        cb->ptr = data;
        if (cb->cliType == cli_autoincrement) {
            ; // no value is transferred for this column type
        } else if (cb->cliType >= cli_asciiz) {
            // Varying value: 4-byte element count followed by the elements
            cb->len = unpack4(data);
            data += 4 + cb->len*cb->fd->components->dbsSize;
            offs = DOALIGN(offs, cb->fd->components->alignment)
                 + cb->len*cb->fd->components->dbsSize;
        } else {
            data += sizeof_type[cb->cliType];
        }
    }
    db->beginTransaction(true);
    db->modified = true;
    oid = db->allocateRow(desc->tableId, offs);
    dst = (char*)db->getRow(oid);

    // Second pass: copy the values into the freshly allocated row
    offs = desc->fixedSize;
    for (cb = stmt->columns; cb != NULL; cb = cb->next) {
        dbFieldDescriptor* fd = cb->fd;
        if (fd->type == dbField::tpArray || fd->type == dbField::tpString) {
            offs = DOALIGN(offs, fd->components->alignment);
            ((dbVarying*)(dst + fd->dbsOffs))->offs = offs;
            ((dbVarying*)(dst + fd->dbsOffs))->size = cb->len;
            offs += cb->unpackArray(dst, offs)*fd->components->dbsSize;
        } else {
            cb->unpackScalar(dst);
        }
    }
    // Third pass: register the new record in the indices of the bound columns
    for (cb = stmt->columns; cb != NULL; cb = cb->next) {
        if (cb->fd->indexType & HASHED) {
            dbHashTable::insert(db, cb->fd->hashTable, oid,
                                cb->fd->type, cb->fd->dbsSize, cb->fd->dbsOffs, 0);
        }
        if (cb->fd->indexType & INDEXED) {
            dbTtree::insert(db, cb->fd->tTree, oid,
                            cb->fd->type, cb->fd->dbsSize, cb->fd->comparator, cb->fd->dbsOffs);
        }
    }
    response = cli_ok;
  return_response:
    pack4(reply_buf, response);
    if (desc == NULL) {
        pack4(reply_buf+4, 0);
    } else {
#ifdef AUTOINCREMENT_SUPPORT
        pack4(reply_buf+4, desc->autoincrementCount);
#else
        pack4(reply_buf+4, ((dbTable*)db->getRow(desc->tableId))->nRows);
#endif
    }
    pack_oid(reply_buf+8, oid);
    return session->sock->write(reply_buf, sizeof reply_buf);
}



/**
 * Prepare and/or execute a SELECT statement for a client session and send
 * the 4-byte response code back over the session socket.
 *
 * When `prepare` is true the message is laid out as: one byte parameter
 * count, one byte column count, two-byte length of the statement text, the
 * statement text (expression fragments, each zero-terminated and followed by
 * a one-byte parameter type when another fragment follows), then the column
 * descriptions consumed by checkColumns(), then the execution data.
 * When `prepare` is false the message contains only the execution data:
 * a one-byte "for update" flag followed by the packed parameter values.
 *
 * @param session  client session owning the statement list and the socket
 * @param stmt_id  client-side statement identifier
 * @param msg      request payload (see above)
 * @param prepare  true to (re)prepare the statement before executing it
 * @return result of writing the response code to the session socket
 */
bool dbServer::select(dbSession* session, int stmt_id, char* msg, bool prepare)
{
    int4 response;
    int i, n_params, tkn, n_columns;
    dbStatement* stmt = findStatement(session, stmt_id);
    dbCursorType cursorType;
    dbTableDescriptor* desc;

    if (prepare) {
        if (stmt == NULL) {
            stmt = new dbStatement(stmt_id);
            stmt->next = session->stmts;
            session->stmts = stmt;
        } else {
            stmt->reset();
        }
        // Both counts are single bytes; read them as unsigned so that values
        // above 127 do not turn negative where plain char is signed (a
        // negative count would be passed to new[] below).
        stmt->n_params = (unsigned char)*msg++;
        stmt->n_columns = n_columns = (unsigned char)*msg++;
        stmt->params = new dbParameterBinding[stmt->n_params];
        stmt->firstFetch = true;
        int len = unpack2(msg);
        msg += 2;
        session->scanner.reset(msg);
        char *p, *end = msg + len;
        if (session->scanner.get() != tkn_select) {
            response = cli_bad_statement;
            goto return_response;
        }
        // Optional "all" keyword between "select" and "from"
        if ((tkn = session->scanner.get()) == tkn_all) {
            tkn = session->scanner.get();
        }
        if (tkn == tkn_from && session->scanner.get() == tkn_ident) {
            if ((desc = db->findTable(session->scanner.ident)) != NULL) {
                // Column descriptions start right after the statement text
                msg = checkColumns(stmt, n_columns, desc, end, response);
                if (response != cli_ok) {
                    goto return_response;
                }
                stmt->cursor = new dbAnyCursor(*desc, dbCursorViewOnly, NULL);
                stmt->cursor->setPrefetchMode(false);
            } else {
                response = cli_table_not_found;
                goto return_response;
            }
        } else {
            response = cli_bad_statement;
            goto return_response;
        }
        // Build the query from the remaining text: expression fragments
        // alternating with parameter placeholders.
        p = session->scanner.p;
        for (i = 0; p < end; i++) {
            stmt->query.append(dbQueryElement::qExpression, p);
            p += strlen(p) + 1;
            if (p < end) {
                static const dbQueryElement::ElementType type_map[] = {
                    dbQueryElement::qVarReference, // cli_oid
                    dbQueryElement::qVarBool,      // cli_bool
                    dbQueryElement::qVarInt1,      // cli_int1
                    dbQueryElement::qVarInt2,      // cli_int2
                    dbQueryElement::qVarInt4,      // cli_int4
                    dbQueryElement::qVarInt8,      // cli_int8
                    dbQueryElement::qVarReal4,     // cli_real4
                    dbQueryElement::qVarReal8,     // cli_real8
                    dbQueryElement::qVarStringPtr, // cli_asciiz
                    dbQueryElement::qVarStringPtr, // cli_pasciiz
                };
                int cliType = *p++;
                // The statement text comes from the client: reject more
                // placeholders than declared parameters and unknown type
                // codes instead of indexing past params[] / type_map[].
                if (i >= stmt->n_params
                    || cliType < 0
                    || cliType >= (int)(sizeof(type_map)/sizeof(type_map[0])))
                {
                    response = cli_bad_statement;
                    goto return_response;
                }
                stmt->params[i].type = cliType;
                stmt->query.append(type_map[cliType], &stmt->params[i].u);
            }
        }
    } else {
        // The statement must exist and own a cursor. The cursor is missing
        // when a previous prepare failed early or when the id belongs to an
        // insert statement (insert() never creates a cursor).
        // NOTE(review): relies on dbStatement leaving `cursor` NULL until a
        // select is prepared - confirm in the dbStatement constructor/reset().
        if (stmt == NULL || stmt->cursor == NULL) {
            response = cli_bad_descriptor;
            goto return_response;
        }
    }
    cursorType = *msg++ ? dbCursorForUpdate : dbCursorViewOnly;
    // Unpack the actual parameter values into the bound parameter slots
    for (i = 0, n_params = stmt->n_params; i < n_params; i++) {
        switch (stmt->params[i].type) {
          case cli_oid:
            stmt->params[i].u.oid = unpack_oid(msg);
            msg += sizeof(cli_oid_t);
            break;
          case cli_int1:
            stmt->params[i].u.i1 = *msg++;
            break;
          case cli_int2:
            msg = unpack2((char*)&stmt->params[i].u.i2, msg);
            break;
          case cli_int4:
            msg = unpack4((char*)&stmt->params[i].u.i4, msg);
            break;
          case cli_int8:
            msg = unpack8((char*)&stmt->params[i].u.i8, msg);
            break;
          case cli_real4:
            msg = unpack4((char*)&stmt->params[i].u.r4, msg);
            break;
          case cli_real8:
            msg = unpack8((char*)&stmt->params[i].u.r8, msg);
            break;
          case cli_bool:
            stmt->params[i].u.b = *msg++;
            break;
          case cli_asciiz:
          case cli_pasciiz:
            // String parameters point directly into the message buffer
            stmt->params[i].u.str = msg;
            msg += strlen(msg) + 1;
            break;
          default:
            response = cli_bad_statement;
            goto return_response;
        }
    }
#ifdef THROW_EXCEPTION_ON_ERROR
    try {
        response = stmt->cursor->select(stmt->query, cursorType);
    } catch (dbException const& x) {
        response = (x.getErrCode() == dbDatabase::QueryError)
            ? cli_bad_statement : cli_runtime_error;
    }
#else
    {
        // Without exceptions, database errors unwind to this setjmp point
        dbDatabaseThreadContext* ctx = db->threadContext.get();
        ctx->catched = true;
        int errorCode = setjmp(ctx->unwind);
        if (errorCode == 0) {
            response = stmt->cursor->select(stmt->query, cursorType);
        } else {
            response = (errorCode == dbDatabase::QueryError)
                ? cli_bad_statement : cli_runtime_error;
        }
        ctx->catched = false;
    }
#endif
  return_response:
    pack4(response);
    return session->sock->write(&response, sizeof response);
}


/**
 * Body of a server worker thread.
 *
 * The thread attaches to the database and then repeatedly: waits on `go`
 * until a session appears in the wait list (or `cancelWait` is set, in which
 * case it signals `done`, detaches and returns), moves the session to the
 * active list, and processes requests read from the session socket until the
 * client closes the session or a socket operation fails. When the session
 * ends, an open transaction is rolled back, the session's remaining
 * statements and its socket are destroyed and the session object is returned
 * to the free list. The thread terminates when `cancelSession` is set or
 * when the number of active plus idle threads has reached
 * `optimalNumberOfThreads`.
 */
void dbServer::serveClient()
{
    dbStatement *sp, **spp;
    db->attach();
    while (true) {
        dbSession* session;
        {
            dbCriticalSection cs(mutex);
            do {
                go.wait(mutex);
                if (cancelWait) {
                    nIdleThreads -= 1;
                    done.signal();
                    db->detach();
                    return;
                }
            } while (waitList == NULL);

            // Move the first waiting session to the active list
            session = waitList;
            waitList = waitList->next;
            session->next = activeList;
            activeList = session;
            nIdleThreads -= 1;
            nActiveThreads += 1;
        }
        cli_request req;
        int4 response = cli_ok;
        bool online = true;
        while (online && session->sock->read(&req, sizeof req)) {
            req.unpack();
            int length = req.length - sizeof(req);
            if (length < 0) {
                // The header announces a request shorter than the header
                // itself: malformed input, drop the connection instead of
                // sizing the message buffer with a negative length.
                break;
            }
            dbSmallBuffer msg(length);
            if (length > 0) {
                if (!session->sock->read(msg, length)) {
                    break;
                }
            }
            switch(req.cmd) {
              case cli_cmd_close_session:
                db->commit();
                session->in_transaction = false;
                online = false;
                break;
              case cli_cmd_prepare_and_execute:
                online = select(session, req.stmt_id, msg, true);
                session->in_transaction = true;
                break;
              case cli_cmd_execute:
                online = select(session, req.stmt_id, msg, false);
                break;
              case cli_cmd_get_first:
                online = get_first(session, req.stmt_id);
                break;
              case cli_cmd_get_last:
                online = get_last(session, req.stmt_id);
                break;
              case cli_cmd_get_next:
                online = get_next(session, req.stmt_id);
                break;
              case cli_cmd_get_prev:
                online = get_prev(session, req.stmt_id);
                break;
              case cli_cmd_free_statement:
                // Unlink and destroy the statement with the requested id
                for (spp = &session->stmts; (sp = *spp) != NULL; spp = &sp->next)
                {
                    if (sp->id == req.stmt_id) {
                        *spp = sp->next;
                        delete sp;
                        break;
                    }
                }
                break;
              case cli_cmd_abort:
                db->rollback();
                session->in_transaction = false;
                online = session->sock->write(&response, sizeof response);
                break;
              case cli_cmd_commit:
                db->commit();
                session->in_transaction = false;
                online = session->sock->write(&response, sizeof response);
                break;
              // NOTE(review): the declarations of update() and remove() are
              // not visible here; if they return the socket write status like
              // insert() does, it should be assigned to `online` as well.
              case cli_cmd_update:
                update(session, req.stmt_id, msg);
                break;
              case cli_cmd_remove:
                remove(session, req.stmt_id);
                break;
              case cli_cmd_prepare_and_insert:
                // insert() returns the status of writing the reply; stop
                // serving the session as soon as the socket fails.
                online = insert(session, req.stmt_id, msg, true);
                session->in_transaction = true;
                break;
              case cli_cmd_insert:
                online = insert(session, req.stmt_id, msg, false);
                break;
            }
        }
        // Destroy statements the client did not free explicitly; the session
        // object is recycled and its statement list is reset on reuse, so
        // they would otherwise be leaked.
        while ((sp = session->stmts) != NULL) {
            session->stmts = sp->next;
            delete sp;
        }
        if (session->in_transaction) {
            db->rollback();
        }
        // Finish session
        {
            dbCriticalSection cs(mutex);
            dbSession** spp;
            delete session->sock;
            // Remove the session from the active list and recycle it
            for (spp = &activeList; *spp != session; spp = &(*spp)->next);
            *spp = session->next;
            session->next = freeList;
            freeList = session;
            nActiveThreads -= 1;
            if (cancelSession) {
                done.signal();
                break;
            }
            if (nActiveThreads + nIdleThreads >= optimalNumberOfThreads) {
                break;
            }
            nIdleThreads += 1;
        }
    }
    db->detach();
}

/**
 * Accept loop for one listening socket.
 *
 * Every accepted connection is bound to a session object (taken from the
 * free list or newly allocated), queued in the wait list and announced to
 * the worker threads through `go`. A new worker thread is started when no
 * idle thread is available. The loop ends when `cancelAccept` is set.
 *
 * @param acceptSock listening socket to accept connections from
 */
void dbServer::acceptConnection(socket_t* acceptSock)
{
    while (true) {
        socket_t* sock = acceptSock->accept();
        dbCriticalSection cs(mutex);
        if (cancelAccept) {
            // A connection may have been accepted just before cancellation;
            // nobody will serve it, so release it (deleting NULL is a no-op).
            delete sock;
            return;
        }
        if (sock != NULL) {
            if (freeList == NULL) {
                freeList = new dbSession;
                freeList->next = NULL;
            }
            // Take a session from the free list and queue it for a worker
            dbSession* session = freeList;
            freeList = session->next;
            session->sock = sock;
            session->stmts = NULL;
            session->next = waitList;
            session->in_transaction = false;
            waitList = session;
            if (nIdleThreads == 0) {
                // No idle worker: start one and count it as idle up front
                dbThread thread;
                nIdleThreads = 1;
                thread.create(serverThread, this);
                thread.detach();
            }
            go.signal();
        }
    }
}

/**
 * Remove this server from the global server chain and release the accept
 * sockets and the URL string it owns.
 */
dbServer::~dbServer()
{
    // Walk the chain of links until the one pointing at this server is found
    dbServer** link = &chain;
    while (*link != this) {
        link = &(*link)->next;
    }
    // Bypass this server in the chain
    *link = next;

    delete globalAcceptSock;
    delete localAcceptSock;
    delete[] URL;
}





⌨️ 快捷键说明

复制代码 Ctrl + C
搜索代码 Ctrl + F
全屏模式 F11
切换主题 Ctrl + Shift + D
显示快捷键 ?
增大字号 Ctrl + =
减小字号 Ctrl + -