diff --git a/lib/network/NetworkConnection.cpp b/lib/network/NetworkConnection.cpp index f8a0ae9f1..ebeb06778 100644 --- a/lib/network/NetworkConnection.cpp +++ b/lib/network/NetworkConnection.cpp @@ -26,7 +26,7 @@ void NetworkConnection::start() boost::asio::async_read(*socket, readBuffer, boost::asio::transfer_exactly(messageHeaderSize), - [this](const auto & ec, const auto & endpoint) { onHeaderReceived(ec); }); + [self = shared_from_this()](const auto & ec, const auto & endpoint) { self->onHeaderReceived(ec); }); } void NetworkConnection::onHeaderReceived(const boost::system::error_code & ec) @@ -42,7 +42,7 @@ void NetworkConnection::onHeaderReceived(const boost::system::error_code & ec) boost::asio::async_read(*socket, readBuffer, boost::asio::transfer_exactly(messageSize), - [this, messageSize](const auto & ec, const auto & endpoint) { onPacketReceived(ec, messageSize); }); + [self = shared_from_this(), messageSize](const auto & ec, const auto & endpoint) { self->onPacketReceived(ec, messageSize); }); } uint32_t NetworkConnection::readPacketSize() @@ -54,7 +54,7 @@ uint32_t NetworkConnection::readPacketSize() readBuffer.sgetn(reinterpret_cast(&messageSize), sizeof(messageSize)); if (messageSize > messageMaxSize) - throw std::runtime_error("Invalid packet size!"); + listener.onDisconnected(shared_from_this(), "Invalid packet size!"); return messageSize; } @@ -88,7 +88,16 @@ void NetworkConnection::sendPacket(const std::vector & message) boost::asio::write(*socket, boost::asio::buffer(messageSize), ec ); boost::asio::write(*socket, boost::asio::buffer(message), ec ); - // FIXME: handle error? + if (ec) + listener.onDisconnected(shared_from_this(), ec.message()); +} + +void NetworkConnection::close() +{ + boost::system::error_code ec; + socket->close(ec); + + //NOTE: ignoring error code } VCMI_LIB_NAMESPACE_END diff --git a/lib/network/NetworkConnection.h b/lib/network/NetworkConnection.h index 76b032bf2..e445329d1 100644 --- a/lib/network/NetworkConnection.h +++ b/lib/network/NetworkConnection.h @@ -13,7 +13,7 @@ VCMI_LIB_NAMESPACE_BEGIN -class NetworkConnection : public INetworkConnection, public std::enable_shared_from_this +class NetworkConnection : public INetworkConnection, std::enable_shared_from_this { static const int messageHeaderSize = sizeof(uint32_t); static const int messageMaxSize = 64 * 1024 * 1024; // arbitrary size to prevent potential massive allocation if we receive garbage input @@ -31,6 +31,7 @@ public: NetworkConnection(INetworkConnectionListener & listener, const std::shared_ptr & socket); void start(); + void close() override; void sendPacket(const std::vector & message) override; }; diff --git a/lib/network/NetworkInterface.h b/lib/network/NetworkInterface.h index ba1752d56..58bd11023 100644 --- a/lib/network/NetworkInterface.h +++ b/lib/network/NetworkInterface.h @@ -17,6 +17,7 @@ class DLL_LINKAGE INetworkConnection : boost::noncopyable public: virtual ~INetworkConnection() = default; virtual void sendPacket(const std::vector & message) = 0; + virtual void close() = 0; }; using NetworkConnectionPtr = std::shared_ptr; @@ -38,8 +39,6 @@ class DLL_LINKAGE INetworkServer : boost::noncopyable public: virtual ~INetworkServer() = default; - virtual void sendPacket(const std::shared_ptr &, const std::vector & message) = 0; - virtual void closeConnection(const std::shared_ptr &) = 0; virtual void start(uint16_t port) = 0; }; diff --git a/lib/network/NetworkServer.cpp b/lib/network/NetworkServer.cpp index 676764c3e..b408ed525 100644 --- a/lib/network/NetworkServer.cpp +++ b/lib/network/NetworkServer.cpp @@ -28,7 +28,7 @@ void NetworkServer::start(uint16_t port) void NetworkServer::startAsyncAccept() { auto upcomingConnection = std::make_shared(*io); - acceptor->async_accept(*upcomingConnection, std::bind(&NetworkServer::connectionAccepted, this, upcomingConnection, _1)); + acceptor->async_accept(*upcomingConnection, [this, upcomingConnection](const auto & ec) { connectionAccepted(upcomingConnection, ec); }); } void NetworkServer::connectionAccepted(std::shared_ptr upcomingConnection, const boost::system::error_code & ec) @@ -46,27 +46,12 @@ void NetworkServer::connectionAccepted(std::shared_ptr upcomingCo startAsyncAccept(); } -void NetworkServer::sendPacket(const std::shared_ptr & connection, const std::vector & message) -{ - connection->sendPacket(message); -} - -void NetworkServer::closeConnection(const std::shared_ptr & connection) -{ - logNetwork->info("Closing connection!"); - assert(connections.count(connection)); - connections.erase(connection); -} - void NetworkServer::onDisconnected(const std::shared_ptr & connection, const std::string & errorMessage) { logNetwork->info("Connection lost! Reason: %s", errorMessage); assert(connections.count(connection)); - if (connections.count(connection)) // how? Connection was explicitly closed before? - { - connections.erase(connection); - listener.onDisconnected(connection, errorMessage); - } + connections.erase(connection); + listener.onDisconnected(connection, errorMessage); } void NetworkServer::onPacketReceived(const std::shared_ptr & connection, const std::vector & message) diff --git a/lib/network/NetworkServer.h b/lib/network/NetworkServer.h index 805adfba5..8fc0e8988 100644 --- a/lib/network/NetworkServer.h +++ b/lib/network/NetworkServer.h @@ -29,9 +29,6 @@ class NetworkServer : public INetworkConnectionListener, public INetworkServer public: NetworkServer(INetworkServerListener & listener, const std::shared_ptr & context); - void sendPacket(const std::shared_ptr &, const std::vector & message) override; - void closeConnection(const std::shared_ptr &) override; - void start(uint16_t port) override; }; diff --git a/lib/serializer/CMemorySerializer.cpp b/lib/serializer/CMemorySerializer.cpp index fd7346793..13f3a59cf 100644 --- a/lib/serializer/CMemorySerializer.cpp +++ b/lib/serializer/CMemorySerializer.cpp @@ -17,7 +17,7 @@ int CMemorySerializer::read(std::byte * data, unsigned size) if(buffer.size() < readPos + size) throw std::runtime_error(boost::str(boost::format("Cannot read past the buffer (accessing index %d, while size is %d)!") % (readPos + size - 1) % buffer.size())); - std::memcpy(data, buffer.data() + readPos, size); + std::copy_n(buffer.data() + readPos, size, data); readPos += size; return size; } @@ -26,7 +26,7 @@ int CMemorySerializer::write(const std::byte * data, unsigned size) { auto oldSize = buffer.size(); //and the pos to write from buffer.resize(oldSize + size); - std::memcpy(buffer.data() + oldSize, data, size); + std::copy_n(data, size, buffer.data() + oldSize); return size; } diff --git a/lib/serializer/CMemorySerializer.h b/lib/serializer/CMemorySerializer.h index 510012d29..caaa4cc24 100644 --- a/lib/serializer/CMemorySerializer.h +++ b/lib/serializer/CMemorySerializer.h @@ -18,7 +18,7 @@ VCMI_LIB_NAMESPACE_BEGIN class DLL_LINKAGE CMemorySerializer : public IBinaryReader, public IBinaryWriter { - std::vector buffer; + std::vector buffer; size_t readPos; //index of the next byte to be read public: diff --git a/lobby/LobbyDatabase.cpp b/lobby/LobbyDatabase.cpp index 47ad5b88a..308089631 100644 --- a/lobby/LobbyDatabase.cpp +++ b/lobby/LobbyDatabase.cpp @@ -149,12 +149,6 @@ void LobbyDatabase::prepareStatements() WHERE roomID = ? )"; - static const std::string setGameRoomPlayerLimitText = R"( - UPDATE gameRooms - SET playerLimit = ? - WHERE roomID = ? - )"; - // SELECT FROM static const std::string getRecentMessageHistoryText = R"( @@ -221,7 +215,7 @@ void LobbyDatabase::prepareStatements() static const std::string isAccountCookieValidText = R"( SELECT COUNT(accountID) FROM accountCookies - WHERE accountID = ? AND cookieUUID = ? AND strftime('%s',CURRENT_TIMESTAMP)- strftime('%s',creationTime) < ? + WHERE accountID = ? AND cookieUUID = ? )"; static const std::string isGameRoomCookieValidText = R"( @@ -269,7 +263,6 @@ void LobbyDatabase::prepareStatements() setAccountOnlineStatement = database->prepare(setAccountOnlineText); setGameRoomStatusStatement = database->prepare(setGameRoomStatusText); - setGameRoomPlayerLimitStatement = database->prepare(setGameRoomPlayerLimitText); getRecentMessageHistoryStatement = database->prepare(getRecentMessageHistoryText); getIdleGameRoomStatement = database->prepare(getIdleGameRoomText); @@ -352,11 +345,6 @@ void LobbyDatabase::setGameRoomStatus(const std::string & roomID, LobbyRoomState setGameRoomStatusStatement->executeOnce(vstd::to_underlying(roomStatus), roomID); } -void LobbyDatabase::setGameRoomPlayerLimit(const std::string & roomID, uint32_t playerLimit) -{ - setGameRoomPlayerLimitStatement->executeOnce(playerLimit, roomID); -} - void LobbyDatabase::insertPlayerIntoGameRoom(const std::string & accountID, const std::string & roomID) { insertGameRoomPlayersStatement->executeOnce(roomID, accountID); @@ -392,11 +380,10 @@ void LobbyDatabase::insertAccessCookie(const std::string & accountID, const std: insertAccessCookieStatement->executeOnce(accountID, accessCookieUUID); } -void LobbyDatabase::updateAccessCookie(const std::string & accountID, const std::string & accessCookieUUID) {} - -void LobbyDatabase::updateAccountLoginTime(const std::string & accountID) {} - -void LobbyDatabase::updateActiveAccount(const std::string & accountID, bool isActive) {} +void LobbyDatabase::updateAccountLoginTime(const std::string & accountID) +{ + assert(0); +} std::string LobbyDatabase::getAccountDisplayName(const std::string & accountID) { @@ -410,16 +397,16 @@ std::string LobbyDatabase::getAccountDisplayName(const std::string & accountID) return result; } -LobbyCookieStatus LobbyDatabase::getGameRoomCookieStatus(const std::string & accountID, const std::string & accessCookieUUID, std::chrono::seconds cookieLifetime) -{ - return {}; -} +//LobbyCookieStatus LobbyDatabase::getGameRoomCookieStatus(const std::string & accountID, const std::string & accessCookieUUID) +//{ +// return {}; +//} -LobbyCookieStatus LobbyDatabase::getAccountCookieStatus(const std::string & accountID, const std::string & accessCookieUUID, std::chrono::seconds cookieLifetime) +LobbyCookieStatus LobbyDatabase::getAccountCookieStatus(const std::string & accountID, const std::string & accessCookieUUID) { bool result = false; - isAccountCookieValidStatement->setBinds(accountID, accessCookieUUID, cookieLifetime.count()); + isAccountCookieValidStatement->setBinds(accountID, accessCookieUUID); if(isAccountCookieValidStatement->execute()) isAccountCookieValidStatement->getColumns(result); isAccountCookieValidStatement->reset(); @@ -429,6 +416,7 @@ LobbyCookieStatus LobbyDatabase::getAccountCookieStatus(const std::string & acco LobbyInviteStatus LobbyDatabase::getAccountInviteStatus(const std::string & accountID, const std::string & roomID) { + assert(0); return {}; } diff --git a/lobby/LobbyDatabase.h b/lobby/LobbyDatabase.h index 43caeae3e..8825b9063 100644 --- a/lobby/LobbyDatabase.h +++ b/lobby/LobbyDatabase.h @@ -62,7 +62,6 @@ public: void setAccountOnline(const std::string & accountID, bool isOnline); void setGameRoomStatus(const std::string & roomID, LobbyRoomState roomStatus); - void setGameRoomPlayerLimit(const std::string & roomID, uint32_t playerLimit); void insertPlayerIntoGameRoom(const std::string & accountID, const std::string & roomID); void deletePlayerFromGameRoom(const std::string & accountID, const std::string & roomID); @@ -75,21 +74,19 @@ public: void insertAccessCookie(const std::string & accountID, const std::string & accessCookieUUID); void insertChatMessage(const std::string & sender, const std::string & roomType, const std::string & roomID, const std::string & messageText); - void updateAccessCookie(const std::string & accountID, const std::string & accessCookieUUID); void updateAccountLoginTime(const std::string & accountID); - void updateActiveAccount(const std::string & accountID, bool isActive); std::vector getActiveGameRooms(); std::vector getActiveAccounts(); - std::vector getAccountsInRoom(const std::string & roomID); +// std::vector getAccountsInRoom(const std::string & roomID); std::vector getRecentMessageHistory(); std::string getIdleGameRoom(const std::string & hostAccountID); std::string getAccountGameRoom(const std::string & accountID); std::string getAccountDisplayName(const std::string & accountID); - LobbyCookieStatus getGameRoomCookieStatus(const std::string & accountID, const std::string & accessCookieUUID, std::chrono::seconds cookieLifetime); - LobbyCookieStatus getAccountCookieStatus(const std::string & accountID, const std::string & accessCookieUUID, std::chrono::seconds cookieLifetime); +// LobbyCookieStatus getGameRoomCookieStatus(const std::string & accountID, const std::string & accessCookieUUID); + LobbyCookieStatus getAccountCookieStatus(const std::string & accountID, const std::string & accessCookieUUID); LobbyInviteStatus getAccountInviteStatus(const std::string & accountID, const std::string & roomID); LobbyRoomState getGameRoomStatus(const std::string & roomID); uint32_t getGameRoomFreeSlots(const std::string & roomID); diff --git a/lobby/LobbyDefines.h b/lobby/LobbyDefines.h index 710d9277d..b2b323f65 100644 --- a/lobby/LobbyDefines.h +++ b/lobby/LobbyDefines.h @@ -36,7 +36,6 @@ struct LobbyChatMessage enum class LobbyCookieStatus : int32_t { INVALID, - EXPIRED, VALID }; diff --git a/lobby/LobbyServer.cpp b/lobby/LobbyServer.cpp index 551425025..e02f64366 100644 --- a/lobby/LobbyServer.cpp +++ b/lobby/LobbyServer.cpp @@ -17,8 +17,6 @@ #include #include -static const auto accountCookieLifetime = std::chrono::hours(24 * 7); - bool LobbyServer::isAccountNameValid(const std::string & accountName) const { if(accountName.size() < 4) @@ -60,7 +58,7 @@ NetworkConnectionPtr LobbyServer::findGameRoom(const std::string & gameRoomID) c void LobbyServer::sendMessage(const NetworkConnectionPtr & target, const JsonNode & json) { - networkServer->sendPacket(target, json.toBytes(true)); + target->sendPacket(json.toBytes(true)); } void LobbyServer::sendAccountCreated(const NetworkConnectionPtr & target, const std::string & accountID, const std::string & accountCookie) @@ -206,16 +204,27 @@ void LobbyServer::onNewConnection(const NetworkConnectionPtr & connection) void LobbyServer::onDisconnected(const NetworkConnectionPtr & connection, const std::string & errorMessage) { if(activeAccounts.count(connection)) + { database->setAccountOnline(activeAccounts.at(connection), false); + activeAccounts.erase(connection); + } if(activeGameRooms.count(connection)) + { database->setGameRoomStatus(activeGameRooms.at(connection), LobbyRoomState::CLOSED); + activeGameRooms.erase(connection); + } - // NOTE: lost connection can be in only one of these lists (or in none of them) - // calling on all possible containers since calling std::map::erase() with non-existing key is legal - activeAccounts.erase(connection); - activeProxies.erase(connection); - activeGameRooms.erase(connection); + if(activeProxies.count(connection)) + { + auto & otherConnection = activeProxies.at(connection); + + if (otherConnection) + otherConnection->close(); + + activeProxies.erase(connection); + activeProxies.erase(otherConnection); + } broadcastActiveAccounts(); broadcastActiveGameRooms(); @@ -226,7 +235,7 @@ void LobbyServer::onPacketReceived(const NetworkConnectionPtr & connection, cons // proxy connection - no processing, only redirect if(activeProxies.count(connection)) { - auto lockedPtr = activeProxies.at(connection).lock(); + auto lockedPtr = activeProxies.at(connection); if(lockedPtr) return lockedPtr->sendPacket(message); @@ -296,9 +305,7 @@ void LobbyServer::onPacketReceived(const NetworkConnectionPtr & connection, cons if(messageType == "serverProxyLogin") return receiveServerProxyLogin(connection, json); - // TODO: add logging of suspicious connections. - networkServer->closeConnection(connection); - + connection->close(); logGlobal->info("(unauthorised): Unknown message type %s", messageType); } @@ -348,13 +355,11 @@ void LobbyServer::receiveClientLogin(const NetworkConnectionPtr & connection, co if(!database->isAccountIDExists(accountID)) return sendOperationFailed(connection, "Account not found"); - auto clientCookieStatus = database->getAccountCookieStatus(accountID, accountCookie, accountCookieLifetime); + auto clientCookieStatus = database->getAccountCookieStatus(accountID, accountCookie); if(clientCookieStatus == LobbyCookieStatus::INVALID) return sendOperationFailed(connection, "Authentification failure"); - // prolong existing cookie - database->updateAccessCookie(accountID, accountCookie); database->updateAccountLoginTime(accountID); database->setAccountOnline(accountID, true); @@ -365,8 +370,8 @@ void LobbyServer::receiveClientLogin(const NetworkConnectionPtr & connection, co sendLoginSuccess(connection, accountCookie, displayName); sendChatHistory(connection, database->getRecentMessageHistory()); - // send active accounts list to new account - // and update acount list to everybody else + // send active game rooms list to new account + // and update acount list to everybody else including new account broadcastActiveAccounts(); sendMessage(connection, prepareActiveGameRooms()); } @@ -378,7 +383,7 @@ void LobbyServer::receiveServerLogin(const NetworkConnectionPtr & connection, co std::string accountCookie = json["accountCookie"].String(); std::string version = json["version"].String(); - auto clientCookieStatus = database->getAccountCookieStatus(accountID, accountCookie, accountCookieLifetime); + auto clientCookieStatus = database->getAccountCookieStatus(accountID, accountCookie); if(clientCookieStatus == LobbyCookieStatus::INVALID) { @@ -399,7 +404,7 @@ void LobbyServer::receiveClientProxyLogin(const NetworkConnectionPtr & connectio std::string accountID = json["accountID"].String(); std::string accountCookie = json["accountCookie"].String(); - auto clientCookieStatus = database->getAccountCookieStatus(accountID, accountCookie, accountCookieLifetime); + auto clientCookieStatus = database->getAccountCookieStatus(accountID, accountCookie); if(clientCookieStatus != LobbyCookieStatus::INVALID) { @@ -424,7 +429,7 @@ void LobbyServer::receiveClientProxyLogin(const NetworkConnectionPtr & connectio } sendOperationFailed(connection, "Invalid credentials"); - networkServer->closeConnection(connection); + connection->close(); } void LobbyServer::receiveServerProxyLogin(const NetworkConnectionPtr & connection, const JsonNode & json) @@ -456,7 +461,7 @@ void LobbyServer::receiveServerProxyLogin(const NetworkConnectionPtr & connectio return; } - //networkServer->closeConnection(connection); + //connection->close(); } void LobbyServer::receiveOpenGameRoom(const NetworkConnectionPtr & connection, const JsonNode & json) @@ -480,9 +485,6 @@ void LobbyServer::receiveOpenGameRoom(const NetworkConnectionPtr & connection, c if(roomType == "private") database->setGameRoomStatus(gameRoomID, LobbyRoomState::PRIVATE); - // TODO: additional flags / initial settings, e.g. allowCheats - // TODO: connection mode: direct or proxy. For now direct is assumed. Proxy might be needed later, for hosted servers - database->insertPlayerIntoGameRoom(accountID, gameRoomID); broadcastActiveGameRooms(); sendJoinRoomSuccess(connection, gameRoomID, false); diff --git a/lobby/LobbyServer.h b/lobby/LobbyServer.h index 6a3ebc5fa..e61b0111c 100644 --- a/lobby/LobbyServer.h +++ b/lobby/LobbyServer.h @@ -29,7 +29,7 @@ class LobbyServer final : public INetworkServerListener }; /// list of connected proxies. All messages received from (key) will be redirected to (value) connection - std::map activeProxies; + std::map activeProxies; /// list of half-established proxies from server that are still waiting for client to connect std::vector awaitingProxies; diff --git a/server/CVCMIServer.cpp b/server/CVCMIServer.cpp index 6fe6ee669..9f170b33d 100644 --- a/server/CVCMIServer.cpp +++ b/server/CVCMIServer.cpp @@ -160,7 +160,7 @@ void CVCMIServer::onNewConnection(const std::shared_ptr & co } else { - networkServer->closeConnection(connection); + connection->close(); } } @@ -445,7 +445,7 @@ void CVCMIServer::clientConnected(std::shared_ptr c, std::vector c) { - networkServer->closeConnection(c->getConnection()); + c->getConnection()->close(); vstd::erase(activeConnections, c); if(activeConnections.empty() || hostClientId == c->connectionID) diff --git a/server/GlobalLobbyProcessor.cpp b/server/GlobalLobbyProcessor.cpp index 23a5037a7..3bafe24f5 100644 --- a/server/GlobalLobbyProcessor.cpp +++ b/server/GlobalLobbyProcessor.cpp @@ -29,7 +29,15 @@ void GlobalLobbyProcessor::establishNewConnection() void GlobalLobbyProcessor::onDisconnected(const std::shared_ptr & connection, const std::string & errorMessage) { - throw std::runtime_error("Lost connection to a lobby server!"); + if (connection == controlConnection) + { + throw std::runtime_error("Lost connection to a lobby server!"); + } + else + { + // player disconnected + owner.onDisconnected(connection, errorMessage); + } } void GlobalLobbyProcessor::onPacketReceived(const std::shared_ptr & connection, const std::vector & message) @@ -47,7 +55,7 @@ void GlobalLobbyProcessor::onPacketReceived(const std::shared_ptrerror("Received unexpected message from lobby server of type '%s' ", json["type"].String()); } else {