From 2260a7512aab7a50473466cc524aabb2a5002edf Mon Sep 17 00:00:00 2001 From: David Steele Date: Thu, 16 Apr 2020 15:02:33 -0400 Subject: [PATCH] Use poll() instead of select() for monitoring socket read/write ready. select() is a bit old-fashioned and cumbersome to use. Since the select() code needed to be modified to handle write ready this seems like a good time to upgrade to poll(). poll() has been around for a long time so there doesn't seem to be any need to provide a fallback to select(). Also change the error on timeout from FileReadError to ProtocolError. This works better for read vs. write and failure to poll() is indicative of a protocol error or unexpected EOF. --- doc/xml/release.xml | 10 +++ src/common/io/socket/common.c | 105 ++++++++++++++++++++++++++++ src/common/io/socket/common.h | 7 ++ src/common/io/socket/session.c | 47 ++++++------- src/common/io/socket/session.h | 5 +- src/common/io/tls/session.c | 4 +- test/src/module/common/ioHttpTest.c | 2 +- test/src/module/common/ioTlsTest.c | 56 +++++++++++++-- 8 files changed, 199 insertions(+), 37 deletions(-) diff --git a/doc/xml/release.xml b/doc/xml/release.xml index 1d3299312..f4dcaba8a 100644 --- a/doc/xml/release.xml +++ b/doc/xml/release.xml @@ -45,6 +45,16 @@

TCP keep-alive options are configurable.

+ + + + + + + +

Use poll() instead of select() for monitoring socket read/write ready.

+
+ diff --git a/src/common/io/socket/common.c b/src/common/io/socket/common.c index 89b7c179b..2c6b854d7 100644 --- a/src/common/io/socket/common.c +++ b/src/common/io/socket/common.c @@ -6,12 +6,14 @@ Socket Common Functions #include #include #include +#include #include #include #include "common/debug.h" #include "common/io/socket/common.h" #include "common/log.h" +#include "common/wait.h" /*********************************************************************************************************************************** Local variables @@ -121,3 +123,106 @@ sckOptionSet(int fd) FUNCTION_TEST_RETURN_VOID(); } + +/*********************************************************************************************************************************** +Use poll() to determine when data is ready to read/write on a socket. Retry after EINTR with whatever time is left on the timer. +***********************************************************************************************************************************/ +// Helper to determine when poll() should be retried +static bool +sckReadyRetry(int pollResult, int errNo, bool first, TimeMSec *timeout, TimeMSec timeEnd) +{ + FUNCTION_TEST_BEGIN(); + FUNCTION_TEST_PARAM(INT, pollResult); + FUNCTION_TEST_PARAM(INT, errNo); + FUNCTION_TEST_PARAM(BOOL, first); + FUNCTION_TEST_PARAM_P(TIME_MSEC, timeout); + FUNCTION_TEST_PARAM(TIME_MSEC, timeEnd); + FUNCTION_TEST_END(); + + ASSERT(timeout != NULL); + + // No retry by default + bool result = false; + + // Process errors + if (pollResult == -1) + { + // Don't error on an interrupt. If the interrupt lasts long enough there may be a timeout, though. + if (errNo != EINTR) + THROW_SYS_ERROR_CODE(errNo, KernelError, "unable to poll socket"); + + // Always retry on the first iteration + if (first) + { + result = true; + } + // Else retry if there is time left + else + { + TimeMSec timeCurrent = timeMSec(); + + if (timeEnd > timeCurrent) + { + *timeout = timeEnd - timeCurrent; + result = true; + } + } + } + + FUNCTION_TEST_RETURN(result); +} + +bool +sckReady(int fd, bool read, bool write, TimeMSec timeout) +{ + FUNCTION_LOG_BEGIN(logLevelTrace); + FUNCTION_LOG_PARAM(INT, fd); + FUNCTION_LOG_PARAM(BOOL, read); + FUNCTION_LOG_PARAM(BOOL, write); + FUNCTION_LOG_PARAM(TIME_MSEC, timeout); + FUNCTION_LOG_END(); + + ASSERT(fd >= 0); + ASSERT(read || write); + ASSERT(timeout < INT_MAX); + + // Poll settings + struct pollfd inputFd = {.fd = fd}; + + if (read) + inputFd.events |= POLLIN; + + if (write) + inputFd.events |= POLLOUT; + + // Wait for ready or timeout + TimeMSec timeEnd = timeMSec() + timeout; + bool first = true; + + // Initialize result and errno to look like a retryable error. We have no good way to test this function with interrupts so this + // at least ensures that the condition is retried. + int result = -1; + int errNo = EINTR; + + while (sckReadyRetry(result, errNo, first, &timeout, timeEnd)) + { + result = poll(&inputFd, 1, (int)timeout); + + errNo = errno; + first = false; + } + + FUNCTION_LOG_RETURN(BOOL, result > 0); +} + +bool +sckReadyRead(int fd, TimeMSec timeout) +{ + return sckReady(fd, true, false, timeout); +} + +bool +sckReadyWrite(int fd, TimeMSec timeout) +{ + return sckReady(fd, false, true, timeout); +} diff --git a/src/common/io/socket/common.h b/src/common/io/socket/common.h index 98b392699..b122eb424 100644 --- a/src/common/io/socket/common.h +++ b/src/common/io/socket/common.h @@ -4,6 +4,8 @@ Socket Common Functions #ifndef COMMON_IO_SOCKET_COMMON_H #define COMMON_IO_SOCKET_COMMON_H +#include "common/time.h" + /*********************************************************************************************************************************** Functions ***********************************************************************************************************************************/ @@ -13,4 +15,9 @@ void sckInit(bool keepAlive, int tcpKeepAliveCount, int tcpKeepAliveIdle, int tc // Set options on a socket void sckOptionSet(int fd); +// Wait until the socket is ready to read/write or timeout +bool sckReady(int fd, bool read, bool write, TimeMSec timeout); +bool sckReadyRead(int fd, TimeMSec timeout); +bool sckReadyWrite(int fd, TimeMSec timeout); + #endif diff --git a/src/common/io/socket/session.c b/src/common/io/socket/session.c index 86053b8ae..de1b800b9 100644 --- a/src/common/io/socket/session.c +++ b/src/common/io/socket/session.c @@ -3,10 +3,6 @@ Socket Session ***********************************************************************************************************************************/ #include "build.auto.h" -#include -#include -#include -#include #include #include "common/debug.h" @@ -86,36 +82,37 @@ sckSessionNew(SocketSessionType type, int fd, const String *host, unsigned int p /**********************************************************************************************************************************/ void -sckSessionReadWait(SocketSession *this) +sckSessionReadyRead(SocketSession *this) { FUNCTION_LOG_BEGIN(logLevelTrace); FUNCTION_LOG_PARAM(SOCKET_SESSION, this); FUNCTION_LOG_END(); ASSERT(this != NULL); - ASSERT(this->fd != -1); - // Initialize the file descriptor set used for select - fd_set selectSet; - FD_ZERO(&selectSet); - - // We know the socket is not negative because it passed error handling, so it is safe to cast to unsigned - FD_SET((unsigned int)this->fd, &selectSet); - - // Initialize timeout struct used for select. Recreate this structure each time since Linux (at least) will modify it. - struct timeval timeoutSelect; - timeoutSelect.tv_sec = (time_t)(this->timeout / MSEC_PER_SEC); - timeoutSelect.tv_usec = (time_t)(this->timeout % MSEC_PER_SEC * 1000); - - // Determine if there is data to be read - int result = select(this->fd + 1, &selectSet, NULL, NULL, &timeoutSelect); - THROW_ON_SYS_ERROR_FMT(result == -1, AssertError, "unable to select from '%s:%u'", strPtr(this->host), this->port); - - // If no data available after time allotted then error - if (!result) + if (!sckReadyRead(this->fd, this->timeout)) { THROW_FMT( - FileReadError, "timeout after %" PRIu64 "ms waiting for read from '%s:%u'", this->timeout, strPtr(this->host), + ProtocolError, "timeout after %" PRIu64 "ms waiting for read from '%s:%u'", this->timeout, strPtr(this->host), + this->port); + } + + FUNCTION_LOG_RETURN_VOID(); +} + +void +sckSessionReadyWrite(SocketSession *this) +{ + FUNCTION_LOG_BEGIN(logLevelTrace); + FUNCTION_LOG_PARAM(SOCKET_SESSION, this); + FUNCTION_LOG_END(); + + ASSERT(this != NULL); + + if (!sckReadyWrite(this->fd, this->timeout)) + { + THROW_FMT( + ProtocolError, "timeout after %" PRIu64 "ms waiting for write to '%s:%u'", this->timeout, strPtr(this->host), this->port); } diff --git a/src/common/io/socket/session.h b/src/common/io/socket/session.h index 7a30e96cb..68b541d2c 100644 --- a/src/common/io/socket/session.h +++ b/src/common/io/socket/session.h @@ -39,8 +39,9 @@ Functions // Move to a new parent mem context SocketSession *sckSessionMove(SocketSession *this, MemContext *parentNew); -// Wait for the socket to be readable -void sckSessionReadWait(SocketSession *this); +// Check if there is data ready to read/write on the socket +void sckSessionReadyRead(SocketSession *this); +void sckSessionReadyWrite(SocketSession *this); /*********************************************************************************************************************************** Getters/Setters diff --git a/src/common/io/tls/session.c b/src/common/io/tls/session.c index a89b456f4..8ca2ab44d 100644 --- a/src/common/io/tls/session.c +++ b/src/common/io/tls/session.c @@ -149,7 +149,7 @@ tlsSessionRead(THIS_VOID, Buffer *buffer, bool block) { // If no tls data pending then check the socket if (!SSL_pending(this->session)) - sckSessionReadWait(this->socketSession); + sckSessionReadyRead(this->socketSession); // Read and handle errors result = SSL_read(this->session, bufRemainsPtr(buffer), (int)bufRemains(buffer)); @@ -198,7 +198,7 @@ tlsSessionWriteContinue(TlsSession *this, int writeResult, int writeError, size_ THROW_FMT(FileWriteError, "unable to write to tls [%d]", writeError); // Wait for the socket to be readable for tls renegotiation - sckSessionReadWait(this->socketSession); + sckSessionReadyRead(this->socketSession); } } else diff --git a/test/src/module/common/ioHttpTest.c b/test/src/module/common/ioHttpTest.c index c154cba08..d1f25f195 100644 --- a/test/src/module/common/ioHttpTest.c +++ b/test/src/module/common/ioHttpTest.c @@ -469,7 +469,7 @@ testRun(void) client->timeout = 0; TEST_ERROR_FMT( - httpClientRequest(client, strNew("GET"), strNew("/"), NULL, NULL, NULL, false), FileReadError, + httpClientRequest(client, strNew("GET"), strNew("/"), NULL, NULL, NULL, false), ProtocolError, "timeout after 500ms waiting for read from '%s:%u'", strPtr(harnessTlsTestHost()), harnessTlsTestPort()); // Test invalid http version diff --git a/test/src/module/common/ioTlsTest.c b/test/src/module/common/ioTlsTest.c index 5cf2f5a77..65c104571 100644 --- a/test/src/module/common/ioTlsTest.c +++ b/test/src/module/common/ioTlsTest.c @@ -116,20 +116,21 @@ testRun(void) .ai_protocol = IPPROTO_TCP, }; - struct addrinfo *hostAddress; int result; - const char *host = "127.0.0.1"; const char *port = "7777"; - if ((result = getaddrinfo(host, port, &hints, &hostAddress)) != 0) + const char *hostBad = "172.31.255.255"; + struct addrinfo *hostBadAddress; + + if ((result = getaddrinfo(hostBad, port, &hints, &hostBadAddress)) != 0) { THROW_FMT( // {uncoverable - lookup on IP should never fail} - HostConnectError, "unable to get address for '%s': [%d] %s", host, result, gai_strerror(result)); + HostConnectError, "unable to get address for '%s': [%d] %s", hostBad, result, gai_strerror(result)); } TRY_BEGIN() { - int fd = socket(hostAddress->ai_family, hostAddress->ai_socktype, hostAddress->ai_protocol); + int fd = socket(hostBadAddress->ai_family, hostBadAddress->ai_socktype, hostBadAddress->ai_protocol); THROW_ON_SYS_ERROR(fd == -1, HostConnectError, "unable to create socket"); // --------------------------------------------------------------------------------------------------------------------- @@ -202,11 +203,32 @@ testRun(void) TEST_RESULT_INT(keepAliveCountValue, 32, "check TCP_KEEPCNT"); TEST_RESULT_INT(keepAliveIdleValue, 3113, "check TCP_KEEPIDLE"); TEST_RESULT_INT(keepAliveIntervalValue, 818, "check TCP_KEEPINTVL"); + + // --------------------------------------------------------------------------------------------------------------------- + TEST_TITLE("connect to non-blocking socket to test write ready"); + + // Put the socket in non-blocking mode + int flags; + + THROW_ON_SYS_ERROR((flags = fcntl(fd, F_GETFL)) == -1, ProtocolError, "unable to get flags"); + THROW_ON_SYS_ERROR(fcntl(fd, F_SETFL, flags | O_NONBLOCK) == -1, ProtocolError, "unable to set O_NONBLOCK"); + + // Attempt connection + CHECK(connect(fd, hostBadAddress->ai_addr, hostBadAddress->ai_addrlen) == -1); + + // Create socket session and wait for timeout + SocketSession *session = NULL; + TEST_ASSIGN(session, sckSessionNew(sckSessionTypeClient, fd, strNew(hostBad), 7777, 100), "new socket"); + + TEST_ERROR( + sckSessionReadyWrite(session), ProtocolError, "timeout after 100ms waiting for write to '172.31.255.255:7777'"); + + TEST_RESULT_VOID(sckSessionFree(session), "free socket session"); } FINALLY() { // This needs to be freed or valgrind will complain - freeaddrinfo(hostAddress); + freeaddrinfo(hostBadAddress); } TRY_END(); @@ -362,6 +384,26 @@ testRun(void) "new client"); TEST_ASSIGN(session, tlsClientOpen(client), "open client"); + // ----------------------------------------------------------------------------------------------------------------- + TEST_TITLE("socket read/write ready"); + + TimeMSec timeout = 5757; + TEST_RESULT_BOOL(sckReadyRetry(-1, EINTR, true, &timeout, 0), true, "first retry does not modify timeout"); + TEST_RESULT_UINT(timeout, 5757, " check timeout"); + + timeout = 0; + TEST_RESULT_BOOL(sckReadyRetry(-1, EINTR, false, &timeout, timeMSec() + 10000), true, "retry before timeout"); + TEST_RESULT_BOOL(timeout > 0, true, " check timeout"); + + TEST_RESULT_BOOL(sckReadyRetry(-1, EINTR, false, &timeout, timeMSec()), false, "no retry after timeout"); + TEST_ERROR( + sckReadyRetry(-1, EINVAL, true, &timeout, 0), KernelError, "unable to poll socket: [22] Invalid argument"); + + TEST_RESULT_BOOL(sckReadyRead(session->socketSession->fd, 0), false, "socket is not read ready"); + TEST_RESULT_BOOL(sckReadyWrite(session->socketSession->fd, 100), true, "socket is write ready"); + TEST_RESULT_VOID(sckSessionReadyWrite(session->socketSession), "socket session is write ready"); + + // ----------------------------------------------------------------------------------------------------------------- const Buffer *input = BUFSTRDEF("some protocol info"); TEST_RESULT_VOID(ioWrite(tlsSessionIoWrite(session), input), "write input"); ioWriteFlush(tlsSessionIoWrite(session)); @@ -381,7 +423,7 @@ testRun(void) output = bufNew(12); TEST_ERROR_FMT( - ioRead(tlsSessionIoRead(session), output), FileReadError, + ioRead(tlsSessionIoRead(session), output), ProtocolError, "timeout after 500ms waiting for read from '%s:%u'", strPtr(harnessTlsTestHost()), harnessTlsTestPort()); // -----------------------------------------------------------------------------------------------------------------