1
0
mirror of https://github.com/vcmi/vcmi.git synced 2025-07-05 00:49:09 +02:00

Fix possible memory leak (circular shared_ptr) in networking

This commit is contained in:
Ivan Savenko
2025-04-27 17:06:33 +03:00
parent e567e1b820
commit cd2837a84e
6 changed files with 18 additions and 18 deletions

View File

@ -12,9 +12,9 @@
VCMI_LIB_NAMESPACE_BEGIN
NetworkConnection::NetworkConnection(INetworkConnectionListener & listener, const std::shared_ptr<NetworkSocket> & socket, const std::shared_ptr<NetworkContext> & context)
NetworkConnection::NetworkConnection(INetworkConnectionListener & listener, const std::shared_ptr<NetworkSocket> & socket, NetworkContext & context)
: socket(socket)
, timer(std::make_shared<NetworkTimer>(*context))
, timer(std::make_shared<NetworkTimer>(context))
, listener(listener)
{
socket->set_option(boost::asio::ip::tcp::no_delay(true));
@ -208,7 +208,7 @@ void NetworkConnection::close()
//NOTE: ignoring error code, intended
}
InternalConnection::InternalConnection(INetworkConnectionListener & listener, const std::shared_ptr<NetworkContext> & context)
InternalConnection::InternalConnection(INetworkConnectionListener & listener, NetworkContext & context)
: io(context)
, listener(listener)
{
@ -216,7 +216,7 @@ InternalConnection::InternalConnection(INetworkConnectionListener & listener, co
void InternalConnection::receivePacket(const std::vector<std::byte> & message)
{
boost::asio::post(*io, [self = std::static_pointer_cast<InternalConnection>(shared_from_this()), message](){
boost::asio::post(io, [self = std::static_pointer_cast<InternalConnection>(shared_from_this()), message](){
if (self->connectionActive)
self->listener.onPacketReceived(self, message);
});
@ -224,7 +224,7 @@ void InternalConnection::receivePacket(const std::vector<std::byte> & message)
void InternalConnection::disconnect()
{
boost::asio::post(*io, [self = std::static_pointer_cast<InternalConnection>(shared_from_this())](){
boost::asio::post(io, [self = std::static_pointer_cast<InternalConnection>(shared_from_this())](){
self->listener.onDisconnected(self, "Internal connection has been terminated");
self->otherSideWeak.reset();
self->connectionActive = false;

View File

@ -38,7 +38,7 @@ class NetworkConnection final : public INetworkConnection, public std::enable_sh
void onDataSent(const boost::system::error_code & ec);
public:
NetworkConnection(INetworkConnectionListener & listener, const std::shared_ptr<NetworkSocket> & socket, const std::shared_ptr<NetworkContext> & context);
NetworkConnection(INetworkConnectionListener & listener, const std::shared_ptr<NetworkSocket> & socket, NetworkContext & context);
void start();
void close() override;
@ -49,11 +49,11 @@ public:
class InternalConnection final : public IInternalConnection, public std::enable_shared_from_this<InternalConnection>
{
std::weak_ptr<IInternalConnection> otherSideWeak;
std::shared_ptr<NetworkContext> io;
NetworkContext & io;
INetworkConnectionListener & listener;
bool connectionActive = false;
public:
InternalConnection(INetworkConnectionListener & listener, const std::shared_ptr<NetworkContext> & context);
InternalConnection(INetworkConnectionListener & listener, NetworkContext & context);
void receivePacket(const std::vector<std::byte> & message) override;
void disconnect() override;

View File

@ -21,12 +21,12 @@ std::unique_ptr<INetworkHandler> INetworkHandler::createHandler()
}
NetworkHandler::NetworkHandler()
: io(std::make_shared<NetworkContext>())
: io(std::make_unique<NetworkContext>())
{}
std::unique_ptr<INetworkServer> NetworkHandler::createServerTCP(INetworkServerListener & listener)
{
return std::make_unique<NetworkServer>(listener, io);
return std::make_unique<NetworkServer>(listener, *io);
}
void NetworkHandler::connectToRemote(INetworkClientListener & listener, const std::string & host, uint16_t port)
@ -50,7 +50,7 @@ void NetworkHandler::connectToRemote(INetworkClientListener & listener, const st
listener.onConnectionFailed(error.message());
return;
}
auto connection = std::make_shared<NetworkConnection>(listener, socket, io);
auto connection = std::make_shared<NetworkConnection>(listener, socket, *io);
connection->start();
listener.onConnectionEstablished(connection);
@ -75,7 +75,7 @@ void NetworkHandler::createTimer(INetworkTimerListener & listener, std::chrono::
void NetworkHandler::createInternalConnection(INetworkClientListener & listener, INetworkServer & server)
{
auto localConnection = std::make_shared<InternalConnection>(listener, io);
auto localConnection = std::make_shared<InternalConnection>(listener, *io);
server.receiveInternalConnection(localConnection);

View File

@ -15,7 +15,7 @@ VCMI_LIB_NAMESPACE_BEGIN
class NetworkHandler : public INetworkHandler
{
std::shared_ptr<NetworkContext> io;
std::unique_ptr<NetworkContext> io;
public:
NetworkHandler();

View File

@ -13,7 +13,7 @@
VCMI_LIB_NAMESPACE_BEGIN
NetworkServer::NetworkServer(INetworkServerListener & listener, const std::shared_ptr<NetworkContext> & context)
NetworkServer::NetworkServer(INetworkServerListener & listener, NetworkContext & context)
: io(context)
, listener(listener)
{
@ -21,13 +21,13 @@ NetworkServer::NetworkServer(INetworkServerListener & listener, const std::share
uint16_t NetworkServer::start(uint16_t port)
{
acceptor = std::make_shared<NetworkAcceptor>(*io, boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), port));
acceptor = std::make_shared<NetworkAcceptor>(io, boost::asio::ip::tcp::endpoint(boost::asio::ip::tcp::v4(), port));
return startAsyncAccept();
}
uint16_t NetworkServer::startAsyncAccept()
{
auto upcomingConnection = std::make_shared<NetworkSocket>(*io);
auto upcomingConnection = std::make_shared<NetworkSocket>(io);
acceptor->async_accept(*upcomingConnection, [this, upcomingConnection](const auto & ec) { connectionAccepted(upcomingConnection, ec); });
return acceptor->local_endpoint().port();
}

View File

@ -15,7 +15,7 @@ VCMI_LIB_NAMESPACE_BEGIN
class NetworkServer : public INetworkConnectionListener, public INetworkServer
{
std::shared_ptr<NetworkContext> io;
NetworkContext & io;
std::shared_ptr<NetworkAcceptor> acceptor;
std::set<std::shared_ptr<INetworkConnection>> connections;
@ -27,7 +27,7 @@ class NetworkServer : public INetworkConnectionListener, public INetworkServer
void onDisconnected(const std::shared_ptr<INetworkConnection> & connection, const std::string & errorMessage) override;
void onPacketReceived(const std::shared_ptr<INetworkConnection> & connection, const std::vector<std::byte> & message) override;
public:
NetworkServer(INetworkServerListener & listener, const std::shared_ptr<NetworkContext> & context);
NetworkServer(INetworkServerListener & listener, NetworkContext & context);
void receiveInternalConnection(std::shared_ptr<IInternalConnection> remoteConnection) override;