1
0
mirror of https://github.com/pgbackrest/pgbackrest.git synced 2024-12-12 10:04:14 +02:00

Split session functionality of TlsClient out into TlsSession.

This abstraction allows the session code to be shared between the TLS client and (upcoming) server code.

Session management is no longer implemented in TlsClient so the HttpClient was updated to free and create sessions as needed. No test changes were required for HttpClient so the functionality should be unchanged.

Mechanical changes to the TLS tests were required to use TlsSession where appropriate rather than TlsClient. There should be no change in functionality other than how sessions are managed, i.e. using tlsClientOpen()/tlsSessionFree() rather than just tlsClientOpen().
This commit is contained in:
David Steele 2020-04-14 15:02:18 -04:00
parent c9481bb95f
commit 9f2d647bad
14 changed files with 541 additions and 412 deletions

View File

@ -42,6 +42,14 @@
<p>Split session functionality of <code>SocketClient</code> out into <code>SocketSession</code>.</p>
</release-item>
<release-item>
<release-item-contributor-list>
<release-item-reviewer id="cynthia.shang"/>
</release-item-contributor-list>
<p>Split session functionality of <code>TlsClient</code> out into <code>TlsSession</code>.</p>
</release-item>
<release-item>
<release-item-contributor-list>
<release-item-reviewer id="cynthia.shang"/>

View File

@ -78,6 +78,7 @@ SRCS = \
common/io/socket/common.c \
common/io/socket/session.c \
common/io/tls/client.c \
common/io/tls/session.c \
common/io/write.c \
common/ini.c \
common/lock.c \

View File

@ -55,7 +55,8 @@ struct HttpClient
MemContext *memContext; // Mem context
TimeMSec timeout; // Request timeout
TlsClient *tls; // Tls client
TlsClient *tlsClient; // TLS client
TlsSession *tlsSession; // Current TLS session
IoRead *ioRead; // Read io interface
unsigned int responseCode; // Response code (e.g. 200, 404)
@ -97,8 +98,8 @@ httpClientRead(THIS_VOID, Buffer *buffer, bool block)
// If close was requested and no content specified then the server may send content up until the eof
if (this->closeOnContentEof && !this->contentChunked && this->contentSize == 0)
{
ioRead(tlsClientIoRead(this->tls), buffer);
this->contentEof = ioReadEof(tlsClientIoRead(this->tls));
ioRead(tlsSessionIoRead(this->tlsSession), buffer);
this->contentEof = ioReadEof(tlsSessionIoRead(this->tlsSession));
}
// Else read using specified encoding or size
else
@ -111,7 +112,8 @@ httpClientRead(THIS_VOID, Buffer *buffer, bool block)
// Read length of next chunk
MEM_CONTEXT_TEMP_BEGIN()
{
this->contentRemaining = cvtZToUInt64Base(strPtr(strTrim(ioReadLine(tlsClientIoRead(this->tls)))), 16);
this->contentRemaining = cvtZToUInt64Base(
strPtr(strTrim(ioReadLine(tlsSessionIoRead(this->tlsSession)))), 16);
}
MEM_CONTEXT_TEMP_END();
@ -130,10 +132,10 @@ httpClientRead(THIS_VOID, Buffer *buffer, bool block)
bufLimitSet(buffer, bufSize(buffer) - (bufRemains(buffer) - (size_t)this->contentRemaining));
actualBytes = bufRemains(buffer);
this->contentRemaining -= ioRead(tlsClientIoRead(this->tls), buffer);
this->contentRemaining -= ioRead(tlsSessionIoRead(this->tlsSession), buffer);
// Error if EOF but content read is not complete
if (ioReadEof(tlsClientIoRead(this->tls)))
if (ioReadEof(tlsSessionIoRead(this->tlsSession)))
THROW(FileReadError, "unexpected EOF reading HTTP content");
// Clear limit (this works even if the limit was not set and it is easier than checking)
@ -147,7 +149,7 @@ httpClientRead(THIS_VOID, Buffer *buffer, bool block)
// around to check.
if (this->contentChunked)
{
ioReadLine(tlsClientIoRead(this->tls));
ioReadLine(tlsSessionIoRead(this->tlsSession));
}
// If total content size was provided then this is eof
else
@ -159,7 +161,10 @@ httpClientRead(THIS_VOID, Buffer *buffer, bool block)
// If the server notified that it would close the connection after sending content then close the client side
if (this->contentEof && this->closeOnContentEof)
tlsClientClose(this->tls);
{
tlsSessionFree(this->tlsSession);
this->tlsSession = NULL;
}
}
FUNCTION_LOG_RETURN(SIZE, (size_t)actualBytes);
@ -208,7 +213,7 @@ httpClientNew(
{
.memContext = MEM_CONTEXT_NEW(),
.timeout = timeout,
.tls = tlsClientNew(sckClientNew(host, port, timeout), timeout, verifyPeer, caFile, caPath),
.tlsClient = tlsClientNew(sckClientNew(host, port, timeout), timeout, verifyPeer, caFile, caPath),
};
httpClientStatLocal.object++;
@ -270,14 +275,21 @@ httpClientRequest(
TRY_BEGIN()
{
if (tlsClientOpen(this->tls))
httpClientStatLocal.session++;
if (this->tlsSession == NULL)
{
MEM_CONTEXT_BEGIN(this->memContext)
{
this->tlsSession = tlsClientOpen(this->tlsClient);
httpClientStatLocal.session++;
}
MEM_CONTEXT_END();
}
// Write the request
String *queryStr = httpQueryRender(query);
ioWriteStrLine(
tlsClientIoWrite(this->tls),
tlsSessionIoWrite(this->tlsSession),
strNewFmt(
"%s %s%s%s " HTTP_VERSION "\r", strPtr(verb), strPtr(httpUriEncode(uri, true)), queryStr == NULL ? "" : "?",
queryStr == NULL ? "" : strPtr(queryStr)));
@ -291,23 +303,23 @@ httpClientRequest(
{
const String *headerKey = strLstGet(headerList, headerIdx);
ioWriteStrLine(
tlsClientIoWrite(this->tls),
tlsSessionIoWrite(this->tlsSession),
strNewFmt("%s:%s\r", strPtr(headerKey), strPtr(httpHeaderGet(requestHeader, headerKey))));
}
}
// Write out blank line to end the headers
ioWriteLine(tlsClientIoWrite(this->tls), CR_BUF);
ioWriteLine(tlsSessionIoWrite(this->tlsSession), CR_BUF);
// Write out body if any
if (body != NULL)
ioWrite(tlsClientIoWrite(this->tls), body);
ioWrite(tlsSessionIoWrite(this->tlsSession), body);
// Flush all writes
ioWriteFlush(tlsClientIoWrite(this->tls));
ioWriteFlush(tlsSessionIoWrite(this->tlsSession));
// Read status and make sure it starts with the correct http version
String *status = strTrim(ioReadLine(tlsClientIoRead(this->tls)));
String *status = strTrim(ioReadLine(tlsSessionIoRead(this->tlsSession)));
if (!strBeginsWith(status, HTTP_VERSION_STR))
THROW_FMT(FormatError, "http version of response '%s' must be " HTTP_VERSION, strPtr(status));
@ -338,7 +350,7 @@ httpClientRequest(
do
{
// Read the next header
String *header = strTrim(ioReadLine(tlsClientIoRead(this->tls)));
String *header = strTrim(ioReadLine(tlsSessionIoRead(this->tlsSession)));
// If the header is empty then we have reached the end of the headers
if (strSize(header) == 0)
@ -429,7 +441,10 @@ httpClientRequest(
// If the server notified that it would close the connection and there is no content then close the client side
if (this->closeOnContentEof && !contentExists)
tlsClientClose(this->tls);
{
tlsSessionFree(this->tlsSession);
this->tlsSession = NULL;
}
// Retry when response code is 5xx. These errors generally represent a server error for a request that looks valid.
// There are a few errors that might be permanently fatal but they are rare and it seems best not to try and pick
@ -451,7 +466,8 @@ httpClientRequest(
httpClientStatLocal.retry++;
}
tlsClientClose(this->tls);
tlsSessionFree(this->tlsSession);
this->tlsSession = NULL;
}
TRY_END();
}
@ -504,7 +520,10 @@ httpClientDone(HttpClient *this)
{
// If it looks like we were in the middle of a response then close the TLS session so we can start clean next time
if (!this->contentEof)
tlsClientClose(this->tls);
{
tlsSessionFree(this->tlsSession);
this->tlsSession = NULL;
}
ioReadFree(this->ioRead);
this->ioRead = NULL;

View File

@ -36,6 +36,9 @@ struct SocketClient
OBJECT_DEFINE_MOVE(SOCKET_CLIENT);
OBJECT_DEFINE_GET(Host, const, SOCKET_CLIENT, const String *, host);
OBJECT_DEFINE_GET(Port, const, SOCKET_CLIENT, unsigned int, port);
/**********************************************************************************************************************************/
SocketClient *
sckClientNew(const String *host, unsigned int port, TimeMSec timeout)

View File

@ -49,6 +49,15 @@ SocketClient *sckClientMove(SocketClient *this, MemContext *parentNew);
// Statistics as a formatted string
String *sckClientStatStr(void);
/***********************************************************************************************************************************
Getters/Setters
***********************************************************************************************************************************/
// Socket host
const String *sckClientHost(const SocketClient *this);
// Socket port
unsigned int sckClientPort(const SocketClient *this);
/***********************************************************************************************************************************
Macros for function logging
***********************************************************************************************************************************/

View File

@ -29,9 +29,9 @@ struct SocketSession
TimeMSec timeout; // Timeout for any i/o operation (connect, read, etc.)
};
OBJECT_DEFINE_MOVE(SOCKET_SESSION);
OBJECT_DEFINE_GET(Fd, , SOCKET_SESSION, int, fd);
OBJECT_DEFINE_GET(Host, const, SOCKET_SESSION, const String *, host);
OBJECT_DEFINE_GET(Port, const, SOCKET_SESSION, unsigned int, port);
OBJECT_DEFINE_FREE(SOCKET_SESSION);

View File

@ -27,6 +27,9 @@ SocketSession *sckSessionNew(int fd, const String *host, unsigned int port, Time
/***********************************************************************************************************************************
Functions
***********************************************************************************************************************************/
// Move to a new parent mem context
SocketSession *sckSessionMove(SocketSession *this, MemContext *parentNew);
// Wait for the socket to be readable
void sckSessionReadWait(SocketSession *this);
@ -36,12 +39,6 @@ Getters/Setters
// Socket file descriptor
int sckSessionFd(SocketSession *this);
// Socket host
const String *sckSessionHost(const SocketSession *this);
// Socket port
unsigned int sckSessionPort(const SocketSession *this);
/***********************************************************************************************************************************
Destructor
***********************************************************************************************************************************/

View File

@ -6,17 +6,14 @@ TLS Client
#include <string.h>
#include <strings.h>
#include <openssl/conf.h>
#include <openssl/ssl.h>
#include <openssl/x509v3.h>
#include "common/crypto/common.h"
#include "common/debug.h"
#include "common/log.h"
#include "common/io/tls/client.h"
#include "common/io/io.h"
#include "common/io/read.intern.h"
#include "common/io/write.intern.h"
#include "common/io/tls/client.h"
#include "common/io/tls/session.intern.h"
#include "common/memContext.h"
#include "common/type/object.h"
#include "common/wait.h"
@ -36,17 +33,9 @@ struct TlsClient
bool verifyPeer; // Should the peer (server) certificate be verified?
SocketClient *socketClient; // Socket client
SocketSession *socketSession; // Socket session
SSL_CTX *context; // TLS context
SSL *session; // TLS session on the socket
IoRead *read; // Read interface
IoWrite *write; // Write interface
};
OBJECT_DEFINE_GET(IoRead, , TLS_CLIENT, IoRead *, read);
OBJECT_DEFINE_GET(IoWrite, , TLS_CLIENT, IoWrite *, write);
OBJECT_DEFINE_FREE(TLS_CLIENT);
/***********************************************************************************************************************************
@ -54,62 +43,10 @@ Free connection
***********************************************************************************************************************************/
OBJECT_DEFINE_FREE_RESOURCE_BEGIN(TLS_CLIENT, LOG, logLevelTrace)
{
SSL_free(this->session);
SSL_CTX_free(this->context);
}
OBJECT_DEFINE_FREE_RESOURCE_END(LOG);
/***********************************************************************************************************************************
Report TLS errors. Returns true if the command should continue and false if it should exit.
***********************************************************************************************************************************/
static bool
tlsError(TlsClient *this, int code)
{
FUNCTION_LOG_BEGIN(logLevelTrace);
FUNCTION_LOG_PARAM(TLS_CLIENT, this);
FUNCTION_LOG_PARAM(INT, code);
FUNCTION_LOG_END();
bool result = false;
switch (code)
{
// The connection was closed
case SSL_ERROR_ZERO_RETURN:
{
tlsClientClose(this);
break;
}
// Try the read/write again
case SSL_ERROR_WANT_READ:
case SSL_ERROR_WANT_WRITE:
{
result = true;
break;
}
// A syscall failed (this usually indicates eof)
case SSL_ERROR_SYSCALL:
{
// Get the error before closing so it is not cleared
int errNo = errno;
tlsClientClose(this);
// Throw the sys error if there is one
THROW_ON_SYS_ERROR(errNo, KernelError, "tls failed syscall");
break;
}
// Some other tls error that cannot be handled
default:
THROW_FMT(ServiceError, "tls error [%d]", code);
}
FUNCTION_LOG_RETURN(BOOL, result);
}
/**********************************************************************************************************************************/
TlsClient *
tlsClientNew(SocketClient *socket, TimeMSec timeout, bool verifyPeer, const String *caFile, const String *caPath)
@ -317,174 +254,10 @@ tlsClientHostVerify(const String *host, X509 *certificate)
FUNCTION_LOG_RETURN(BOOL, result);
}
/***********************************************************************************************************************************
Read from the TLS session
***********************************************************************************************************************************/
size_t
tlsClientRead(THIS_VOID, Buffer *buffer, bool block)
{
THIS(TlsClient);
FUNCTION_LOG_BEGIN(logLevelTrace);
FUNCTION_LOG_PARAM(TLS_CLIENT, this);
FUNCTION_LOG_PARAM(BUFFER, buffer);
FUNCTION_LOG_PARAM(BOOL, block);
FUNCTION_LOG_END();
ASSERT(this != NULL);
ASSERT(this->session != NULL);
ASSERT(buffer != NULL);
ASSERT(!bufFull(buffer));
ssize_t result = 0;
// If blocking read keep reading until buffer is full
do
{
// If no tls data pending then check the socket
if (!SSL_pending(this->session))
sckSessionReadWait(this->socketSession);
// Read and handle errors
result = SSL_read(this->session, bufRemainsPtr(buffer), (int)bufRemains(buffer));
if (result <= 0)
{
// Break if the error indicates that we should not continue trying
if (!tlsError(this, SSL_get_error(this->session, (int)result)))
break;
}
// Update amount of buffer used
else
bufUsedInc(buffer, (size_t)result);
}
while (block && bufRemains(buffer) > 0);
FUNCTION_LOG_RETURN(SIZE, (size_t)result);
}
/***********************************************************************************************************************************
Write to the tls session
***********************************************************************************************************************************/
static bool
tlsWriteContinue(TlsClient *this, int writeResult, int writeError, size_t writeSize)
{
FUNCTION_LOG_BEGIN(logLevelTrace);
FUNCTION_LOG_PARAM(TLS_CLIENT, this);
FUNCTION_LOG_PARAM(INT, writeResult);
FUNCTION_LOG_PARAM(INT, writeError);
FUNCTION_LOG_PARAM(SIZE, writeSize);
FUNCTION_LOG_END();
ASSERT(this != NULL);
ASSERT(writeSize > 0);
bool result = true;
// Handle errors
if (writeResult <= 0)
{
// If error = SSL_ERROR_NONE then this is the first write attempt so continue
if (writeError != SSL_ERROR_NONE)
{
// Error if the error indicates that we should not continue trying
if (!tlsError(this, writeError))
THROW_FMT(FileWriteError, "unable to write to tls [%d]", writeError);
// Wait for the socket to be readable for tls renegotiation
sckSessionReadWait(this->socketSession);
}
}
else
{
if ((size_t)writeResult != writeSize)
{
THROW_FMT(
FileWriteError, "unable to write to tls, write size %d does not match expected size %zu", writeResult, writeSize);
}
result = false;
}
FUNCTION_LOG_RETURN(BOOL, result);
}
void
tlsClientWrite(THIS_VOID, const Buffer *buffer)
{
THIS(TlsClient);
FUNCTION_LOG_BEGIN(logLevelTrace);
FUNCTION_LOG_PARAM(TLS_CLIENT, this);
FUNCTION_LOG_PARAM(BUFFER, buffer);
FUNCTION_LOG_END();
ASSERT(this != NULL);
ASSERT(this->session != NULL);
ASSERT(buffer != NULL);
int result = 0;
int error = SSL_ERROR_NONE;
while (tlsWriteContinue(this, result, error, bufUsed(buffer)))
{
result = SSL_write(this->session, bufPtrConst(buffer), (int)bufUsed(buffer));
error = SSL_get_error(this->session, result);
}
FUNCTION_LOG_RETURN_VOID();
}
/***********************************************************************************************************************************
Close the connection
***********************************************************************************************************************************/
void
tlsClientClose(TlsClient *this)
{
FUNCTION_LOG_BEGIN(logLevelTrace);
FUNCTION_LOG_PARAM(TLS_CLIENT, this);
FUNCTION_LOG_END();
ASSERT(this != NULL);
// Close the socket
if (this->socketSession != NULL)
{
sckSessionFree(this->socketSession);
this->socketSession = NULL;
}
// Free the TLS session
if (this->session != NULL)
{
SSL_free(this->session);
this->session = NULL;
}
FUNCTION_LOG_RETURN_VOID();
}
/***********************************************************************************************************************************
Has session been closed by the server?
***********************************************************************************************************************************/
bool
tlsClientEof(THIS_VOID)
{
THIS(TlsClient);
FUNCTION_LOG_BEGIN(logLevelTrace);
FUNCTION_LOG_PARAM(TLS_CLIENT, this);
FUNCTION_LOG_END();
ASSERT(this != NULL);
FUNCTION_LOG_RETURN(BOOL, this->session == NULL);
}
/***********************************************************************************************************************************
Open connection if this is a new client or if the connection was closed by the server
***********************************************************************************************************************************/
bool
TlsSession *
tlsClientOpen(TlsClient *this)
{
FUNCTION_LOG_BEGIN(logLevelTrace)
@ -493,114 +266,92 @@ tlsClientOpen(TlsClient *this)
ASSERT(this != NULL);
bool result = false;
TlsSession *result = NULL;
SSL *session = NULL;
if (this->session == NULL)
MEM_CONTEXT_TEMP_BEGIN()
{
// Free the read/write interfaces
ioReadFree(this->read);
this->read = NULL;
ioWriteFree(this->write);
this->write = NULL;
bool connected = false;
bool retry;
Wait *wait = waitNew(this->timeout);
MEM_CONTEXT_TEMP_BEGIN()
do
{
bool connected = false;
bool retry;
Wait *wait = waitNew(this->timeout);
// Assume there will be no retry
retry = false;
do
TRY_BEGIN()
{
// Assume there will be no retry
retry = false;
// Create internal TLS session
cryptoError((session = SSL_new(this->context)) == NULL, "unable to create TLS session");
TRY_BEGIN()
// Set server host name used for validation
cryptoError(
SSL_set_tlsext_host_name(session, strPtr(sckClientHost(this->socketClient))) != 1,
"unable to set TLS host name");
// Create the TLS session
result = tlsSessionNew(session, sckClientOpen(this->socketClient), this->timeout);
// Connection was successful
connected = true;
}
CATCH_ANY()
{
tlsSessionFree(result);
result = NULL;
// Retry if wait time has not expired
if (waitMore(wait))
{
// Open the socket
MEM_CONTEXT_BEGIN(this->memContext)
{
this->socketSession = sckClientOpen(this->socketClient);
}
MEM_CONTEXT_END();
LOG_DEBUG_FMT("retry %s: %s", errorTypeName(errorType()), errorMessage());
retry = true;
// Negotiate TLS
cryptoError((this->session = SSL_new(this->context)) == NULL, "unable to create TLS context");
cryptoError(
SSL_set_tlsext_host_name(this->session, strPtr(sckSessionHost(this->socketSession))) != 1,
"unable to set TLS host name");
cryptoError(
SSL_set_fd(this->session, sckSessionFd(this->socketSession)) != 1, "unable to add socket to TLS context");
cryptoError(SSL_connect(this->session) != 1, "unable to negotiate TLS connection");
// Connection was successful
connected = true;
tlsClientStatLocal.retry++;
}
CATCH_ANY()
{
// Retry if wait time has not expired
if (waitMore(wait))
{
LOG_DEBUG_FMT("retry %s: %s", errorTypeName(errorType()), errorMessage());
retry = true;
tlsClientStatLocal.retry++;
}
tlsClientClose(this);
}
TRY_END();
}
while (!connected && retry);
if (!connected)
RETHROW();
TRY_END();
}
MEM_CONTEXT_TEMP_END();
while (!connected && retry);
// Verify that the certificate presented by the server is valid
if (this->verifyPeer)
if (!connected)
RETHROW();
tlsSessionMove(result, memContextPrior());
}
MEM_CONTEXT_TEMP_END();
tlsClientStatLocal.session++;
// Verify that the certificate presented by the server is valid
if (this->verifyPeer)
{
// Verify that the chain of trust leads to a valid CA
long int verifyResult = SSL_get_verify_result(session);
if (verifyResult != X509_V_OK)
{
// Verify that the chain of trust leads to a valid CA
long int verifyResult = SSL_get_verify_result(this->session);
if (verifyResult != X509_V_OK)
{
THROW_FMT(
CryptoError, "unable to verify certificate presented by '%s:%u': [%ld] %s",
strPtr(sckSessionHost(this->socketSession)), sckSessionPort(this->socketSession), verifyResult,
X509_verify_cert_error_string(verifyResult));
}
// Verify that the hostname appears in the certificate
X509 *certificate = SSL_get_peer_certificate(this->session);
bool nameResult = tlsClientHostVerify(sckSessionHost(this->socketSession), certificate);
X509_free(certificate);
if (!nameResult)
{
THROW_FMT(
CryptoError,
"unable to find hostname '%s' in certificate common name or subject alternative names",
strPtr(sckSessionHost(this->socketSession)));
}
THROW_FMT(
CryptoError, "unable to verify certificate presented by '%s:%u': [%ld] %s",
strPtr(sckClientHost(this->socketClient)), sckClientPort(this->socketClient), verifyResult,
X509_verify_cert_error_string(verifyResult));
}
MEM_CONTEXT_BEGIN(this->memContext)
// Verify that the hostname appears in the certificate
X509 *certificate = SSL_get_peer_certificate(session);
bool nameResult = tlsClientHostVerify(sckClientHost(this->socketClient), certificate);
X509_free(certificate);
if (!nameResult)
{
// Create read and write interfaces
this->write = ioWriteNewP(this, .write = tlsClientWrite);
ioWriteOpen(this->write);
this->read = ioReadNewP(this, .block = true, .eof = tlsClientEof, .read = tlsClientRead);
ioReadOpen(this->read);
THROW_FMT(
CryptoError,
"unable to find hostname '%s' in certificate common name or subject alternative names",
strPtr(sckClientHost(this->socketClient)));
}
MEM_CONTEXT_END();
tlsClientStatLocal.session++;
result = true;
}
FUNCTION_LOG_RETURN(BOOL, result);
FUNCTION_LOG_RETURN(TLS_SESSION, result);
}
/**********************************************************************************************************************************/

View File

@ -1,17 +1,10 @@
/***********************************************************************************************************************************
TLS Client
A simple, secure TLS client intended to allow access to services that are exposed via HTTPS. We call it TLS instead of SSL because
A simple, secure TLS client intended to allow access to services that are exposed via HTTPS. We call it TLS instead of SSL because
SSL methods are disabled so only TLS connections are allowed.
This object is intended to be used for multiple TLS connections against a service so tlsClientOpen() can be called each time a new
connection is needed. By default, an open connection will be reused so the user must be prepared to retry their transaction on a
read/write error if the server closes the connection before it is reused. If this behavior is not desirable then tlsClientClose()
may be used to ensure that the next call to tlsClientOpen() will create a new TLS session.
Note that tlsClientRead() is non-blocking unless there are *zero* bytes to be read from the session in which case it will raise an
error after the defined timeout. In any case the tlsClientRead()/tlsClientWrite()/tlsClientEof() functions should not generally
be called directly. Instead use the read/write interfaces available from tlsClientIoRead()/tlsClientIoWrite().
This object is intended to be used for multiple TLS sessions so tlsClientOpen() can be called each time a new session is needed.
***********************************************************************************************************************************/
#ifndef COMMON_IO_TLS_CLIENT_H
#define COMMON_IO_TLS_CLIENT_H
@ -25,10 +18,7 @@ Object type
typedef struct TlsClient TlsClient;
#include "common/io/socket/client.h"
#include "common/io/read.h"
#include "common/io/write.h"
#include "common/time.h"
#include "common/type/string.h"
#include "common/io/tls/session.h"
/***********************************************************************************************************************************
Statistics
@ -48,20 +38,8 @@ TlsClient *tlsClientNew(SocketClient *socket, TimeMSec timeout, bool verifyPeer,
/***********************************************************************************************************************************
Functions
***********************************************************************************************************************************/
// Open tls connection
bool tlsClientOpen(TlsClient *this);
// Close tls connection
void tlsClientClose(TlsClient *this);
/***********************************************************************************************************************************
Getters/Setters
***********************************************************************************************************************************/
// Read interface
IoRead *tlsClientIoRead(TlsClient *this);
// Write interface
IoWrite *tlsClientIoWrite(TlsClient *this);
// Open tls session
TlsSession *tlsClientOpen(TlsClient *this);
// Statistics as a formatted string
String *tlsClientStatStr(void);

303
src/common/io/tls/session.c Normal file
View File

@ -0,0 +1,303 @@
/***********************************************************************************************************************************
TLS Session
***********************************************************************************************************************************/
#include "build.auto.h"
#include "common/crypto/common.h"
#include "common/debug.h"
#include "common/io/io.h"
#include "common/io/read.intern.h"
#include "common/io/tls/session.intern.h"
#include "common/io/write.intern.h"
#include "common/log.h"
#include "common/memContext.h"
#include "common/type/object.h"
/***********************************************************************************************************************************
Object type
***********************************************************************************************************************************/
struct TlsSession
{
MemContext *memContext; // Mem context
SocketSession *socketSession; // Socket session
SSL *session; // TLS session on the socket
TimeMSec timeout; // Timeout for any i/o operation (connect, read, etc.)
IoRead *read; // Read interface
IoWrite *write; // Write interface
};
OBJECT_DEFINE_MOVE(TLS_SESSION);
OBJECT_DEFINE_GET(IoRead, , TLS_SESSION, IoRead *, read);
OBJECT_DEFINE_GET(IoWrite, , TLS_SESSION, IoWrite *, write);
OBJECT_DEFINE_FREE(TLS_SESSION);
/***********************************************************************************************************************************
Free connection
***********************************************************************************************************************************/
OBJECT_DEFINE_FREE_RESOURCE_BEGIN(TLS_SESSION, LOG, logLevelTrace)
{
SSL_free(this->session);
}
OBJECT_DEFINE_FREE_RESOURCE_END(LOG);
/***********************************************************************************************************************************
Close the connection
***********************************************************************************************************************************/
static void
tlsSessionClose(TlsSession *this)
{
FUNCTION_LOG_BEGIN(logLevelTrace);
FUNCTION_LOG_PARAM(TLS_SESSION, this);
FUNCTION_LOG_END();
ASSERT(this != NULL);
// If not already closed
if (this->session != NULL)
{
// Free the socket session
sckSessionFree(this->socketSession);
this->socketSession = NULL;
// Free the TLS session
memContextCallbackClear(this->memContext);
tlsSessionFreeResource(this);
this->session = NULL;
}
FUNCTION_LOG_RETURN_VOID();
}
/***********************************************************************************************************************************
Report TLS errors. Returns true if the command should continue and false if it should exit.
***********************************************************************************************************************************/
static bool
tlsSessionError(TlsSession *this, int code)
{
FUNCTION_LOG_BEGIN(logLevelTrace);
FUNCTION_LOG_PARAM(TLS_SESSION, this);
FUNCTION_LOG_PARAM(INT, code);
FUNCTION_LOG_END();
bool result = false;
switch (code)
{
// The connection was closed
case SSL_ERROR_ZERO_RETURN:
{
tlsSessionClose(this);
break;
}
// Try the read/write again
case SSL_ERROR_WANT_READ:
case SSL_ERROR_WANT_WRITE:
{
result = true;
break;
}
// A syscall failed (this usually indicates eof)
case SSL_ERROR_SYSCALL:
{
// Get the error before closing so it is not cleared
int errNo = errno;
tlsSessionClose(this);
// Throw the sys error if there is one
THROW_ON_SYS_ERROR(errNo, KernelError, "tls failed syscall");
break;
}
// Some other tls error that cannot be handled
default:
THROW_FMT(ServiceError, "tls error [%d]", code);
}
FUNCTION_LOG_RETURN(BOOL, result);
}
/***********************************************************************************************************************************
Read from the TLS session
***********************************************************************************************************************************/
static size_t
tlsSessionRead(THIS_VOID, Buffer *buffer, bool block)
{
THIS(TlsSession);
FUNCTION_LOG_BEGIN(logLevelTrace);
FUNCTION_LOG_PARAM(TLS_SESSION, this);
FUNCTION_LOG_PARAM(BUFFER, buffer);
FUNCTION_LOG_PARAM(BOOL, block);
FUNCTION_LOG_END();
ASSERT(this != NULL);
ASSERT(this->session != NULL);
ASSERT(buffer != NULL);
ASSERT(!bufFull(buffer));
ssize_t result = 0;
// If blocking read keep reading until buffer is full
do
{
// If no tls data pending then check the socket
if (!SSL_pending(this->session))
sckSessionReadWait(this->socketSession);
// Read and handle errors
result = SSL_read(this->session, bufRemainsPtr(buffer), (int)bufRemains(buffer));
if (result <= 0)
{
// Break if the error indicates that we should not continue trying
if (!tlsSessionError(this, SSL_get_error(this->session, (int)result)))
break;
}
// Update amount of buffer used
else
bufUsedInc(buffer, (size_t)result);
}
while (block && bufRemains(buffer) > 0);
FUNCTION_LOG_RETURN(SIZE, (size_t)result);
}
/***********************************************************************************************************************************
Write to the tls session
***********************************************************************************************************************************/
static bool
tlsSessionWriteContinue(TlsSession *this, int writeResult, int writeError, size_t writeSize)
{
FUNCTION_LOG_BEGIN(logLevelTrace);
FUNCTION_LOG_PARAM(TLS_SESSION, this);
FUNCTION_LOG_PARAM(INT, writeResult);
FUNCTION_LOG_PARAM(INT, writeError);
FUNCTION_LOG_PARAM(SIZE, writeSize);
FUNCTION_LOG_END();
ASSERT(this != NULL);
ASSERT(writeSize > 0);
bool result = true;
// Handle errors
if (writeResult <= 0)
{
// If error = SSL_ERROR_NONE then this is the first write attempt so continue
if (writeError != SSL_ERROR_NONE)
{
// Error if the error indicates that we should not continue trying
if (!tlsSessionError(this, writeError))
THROW_FMT(FileWriteError, "unable to write to tls [%d]", writeError);
// Wait for the socket to be readable for tls renegotiation
sckSessionReadWait(this->socketSession);
}
}
else
{
if ((size_t)writeResult != writeSize)
{
THROW_FMT(
FileWriteError, "unable to write to tls, write size %d does not match expected size %zu", writeResult, writeSize);
}
result = false;
}
FUNCTION_LOG_RETURN(BOOL, result);
}
static void
tlsSessionWrite(THIS_VOID, const Buffer *buffer)
{
THIS(TlsSession);
FUNCTION_LOG_BEGIN(logLevelTrace);
FUNCTION_LOG_PARAM(TLS_SESSION, this);
FUNCTION_LOG_PARAM(BUFFER, buffer);
FUNCTION_LOG_END();
ASSERT(this != NULL);
ASSERT(this->session != NULL);
ASSERT(buffer != NULL);
int result = 0;
int error = SSL_ERROR_NONE;
while (tlsSessionWriteContinue(this, result, error, bufUsed(buffer)))
{
result = SSL_write(this->session, bufPtrConst(buffer), (int)bufUsed(buffer));
error = SSL_get_error(this->session, result);
}
FUNCTION_LOG_RETURN_VOID();
}
/***********************************************************************************************************************************
Has session been closed by the server?
***********************************************************************************************************************************/
static bool
tlsSessionEof(THIS_VOID)
{
THIS(TlsSession);
FUNCTION_LOG_BEGIN(logLevelTrace);
FUNCTION_LOG_PARAM(TLS_SESSION, this);
FUNCTION_LOG_END();
ASSERT(this != NULL);
FUNCTION_LOG_RETURN(BOOL, this->session == NULL);
}
/**********************************************************************************************************************************/
TlsSession *
tlsSessionNew(SSL *session, SocketSession *socketSession, TimeMSec timeout)
{
FUNCTION_LOG_BEGIN(logLevelDebug)
FUNCTION_LOG_PARAM_P(VOID, session);
FUNCTION_LOG_PARAM(SOCKET_SESSION, socketSession);
FUNCTION_LOG_PARAM(TIME_MSEC, timeout);
FUNCTION_LOG_END();
ASSERT(session != NULL);
ASSERT(socketSession != NULL);
TlsSession *this = NULL;
MEM_CONTEXT_NEW_BEGIN("TlsSession")
{
this = memNew(sizeof(TlsSession));
*this = (TlsSession)
{
.memContext = MEM_CONTEXT_NEW(),
.session = session,
.socketSession = sckSessionMove(socketSession, MEM_CONTEXT_NEW()),
.timeout = timeout,
};
// Initiate TLS connection
cryptoError(
SSL_set_fd(this->session, sckSessionFd(this->socketSession)) != 1, "unable to add socket to TLS session");
cryptoError(SSL_connect(this->session) != 1, "unable to negotiate TLS connection");
memContextCallbackSet(this->memContext, tlsSessionFreeResource, this);
// Create read and write interfaces
this->write = ioWriteNewP(this, .write = tlsSessionWrite);
ioWriteOpen(this->write);
this->read = ioReadNewP(this, .block = true, .eof = tlsSessionEof, .read = tlsSessionRead);
ioReadOpen(this->read);
}
MEM_CONTEXT_NEW_END();
FUNCTION_LOG_RETURN(TLS_SESSION, this);
}

View File

@ -0,0 +1,53 @@
/***********************************************************************************************************************************
TLS Session
TLS sessions are created by calling tlsClientOpen().
TLS sessions are generally reused so the user must be prepared to retry their transaction on a read/write error if the server closes
the session before it is reused. If this behavior is not desirable then tlsSessionFree()/tlsClientOpen() can be called to get a new
session.
***********************************************************************************************************************************/
#ifndef COMMON_IO_TLS_SESSION_H
#define COMMON_IO_TLS_SESSION_H
/***********************************************************************************************************************************
Object type
***********************************************************************************************************************************/
#define TLS_SESSION_TYPE TlsSession
#define TLS_SESSION_PREFIX tlsSession
typedef struct TlsSession TlsSession;
#include "common/io/read.h"
#include "common/io/socket/session.h"
#include "common/io/write.h"
/***********************************************************************************************************************************
Functions
***********************************************************************************************************************************/
// Move to a new parent mem context
TlsSession *tlsSessionMove(TlsSession *this, MemContext *parentNew);
/***********************************************************************************************************************************
Getters/Setters
***********************************************************************************************************************************/
// Read interface
IoRead *tlsSessionIoRead(TlsSession *this);
// Write interface
IoWrite *tlsSessionIoWrite(TlsSession *this);
/***********************************************************************************************************************************
Destructor
***********************************************************************************************************************************/
void tlsSessionFree(TlsSession *this);
/***********************************************************************************************************************************
Macros for function logging
***********************************************************************************************************************************/
#define FUNCTION_LOG_TLS_SESSION_TYPE \
TlsSession *
#define FUNCTION_LOG_TLS_SESSION_FORMAT(value, buffer, bufferSize) \
objToLog(value, "TlsSession", buffer, bufferSize)
#endif

View File

@ -0,0 +1,17 @@
/***********************************************************************************************************************************
TLS Session Internal
***********************************************************************************************************************************/
#ifndef COMMON_IO_TLS_SESSION_INTERN_H
#define COMMON_IO_TLS_SESSION_INTERN_H
#include <openssl/ssl.h>
#include "common/io/tls/session.h"
/***********************************************************************************************************************************
Constructors
***********************************************************************************************************************************/
// Only called by TLS client/server code
TlsSession *tlsSessionNew(SSL *session, SocketSession *socketSession, TimeMSec timeout);
#endif

View File

@ -239,11 +239,12 @@ unit:
# ----------------------------------------------------------------------------------------------------------------------------
- name: io-tls
total: 5
total: 4
containerReq: true
coverage:
common/io/tls/client: full
common/io/tls/session: full
common/io/socket/client: full
common/io/socket/common: full
common/io/socket/session: full

View File

@ -226,21 +226,6 @@ testRun(void)
TEST_RESULT_BOOL(tlsClientHostVerifyName(strNew("a.bogus.host.com"), strNew("*.host.com")), false, "invalid host");
}
// Additional coverage not provided by other tests
// *****************************************************************************************************************************
if (testBegin("tlsError()"))
{
TlsClient *client = NULL;
TEST_ASSIGN(
client, tlsClientNew(sckClientNew(strNew("99.99.99.99.99"), harnessTlsTestPort(), 0), 0, true, NULL, NULL),
"new client");
TEST_RESULT_BOOL(tlsError(client, SSL_ERROR_WANT_READ), true, "continue after want read");
TEST_RESULT_BOOL(tlsError(client, SSL_ERROR_ZERO_RETURN), false, "check connection closed error");
TEST_ERROR(tlsError(client, SSL_ERROR_WANT_X509_LOOKUP), ServiceError, "tls error [4]");
}
// *****************************************************************************************************************************
if (testBegin("TlsClient verification"))
{
@ -346,6 +331,7 @@ testRun(void)
if (testBegin("TlsClient general usage"))
{
TlsClient *client = NULL;
TlsSession *session = NULL;
// Reset statistics
sckClientStatLocal = (SocketClientStat){0};
@ -369,53 +355,56 @@ testRun(void)
client,
tlsClientNew(sckClientNew(harnessTlsTestHost(), harnessTlsTestPort(), 500), 500, testContainer(), NULL, NULL),
"new client");
TEST_RESULT_VOID(tlsClientOpen(client), "open client");
TEST_ASSIGN(session, tlsClientOpen(client), "open client");
const Buffer *input = BUFSTRDEF("some protocol info");
TEST_RESULT_VOID(ioWrite(tlsClientIoWrite(client), input), "write input");
ioWriteFlush(tlsClientIoWrite(client));
TEST_RESULT_VOID(ioWrite(tlsSessionIoWrite(session), input), "write input");
ioWriteFlush(tlsSessionIoWrite(session));
TEST_RESULT_STR_Z(ioReadLine(tlsClientIoRead(client)), "something:0", "read line");
TEST_RESULT_BOOL(ioReadEof(tlsClientIoRead(client)), false, " check eof = false");
TEST_RESULT_STR_Z(ioReadLine(tlsSessionIoRead(session)), "something:0", "read line");
TEST_RESULT_BOOL(ioReadEof(tlsSessionIoRead(session)), false, " check eof = false");
Buffer *output = bufNew(12);
TEST_RESULT_UINT(ioRead(tlsClientIoRead(client), output), 12, "read output");
TEST_RESULT_UINT(ioRead(tlsSessionIoRead(session), output), 12, "read output");
TEST_RESULT_STR_Z(strNewBuf(output), "some content", " check output");
TEST_RESULT_BOOL(ioReadEof(tlsClientIoRead(client)), false, " check eof = false");
TEST_RESULT_BOOL(ioReadEof(tlsSessionIoRead(session)), false, " check eof = false");
output = bufNew(8);
TEST_RESULT_UINT(ioRead(tlsClientIoRead(client), output), 8, "read output");
TEST_RESULT_UINT(ioRead(tlsSessionIoRead(session), output), 8, "read output");
TEST_RESULT_STR_Z(strNewBuf(output), "AND MORE", " check output");
TEST_RESULT_BOOL(ioReadEof(tlsClientIoRead(client)), false, " check eof = false");
TEST_RESULT_BOOL(ioReadEof(tlsSessionIoRead(session)), false, " check eof = false");
output = bufNew(12);
TEST_ERROR_FMT(
ioRead(tlsClientIoRead(client), output), FileReadError,
ioRead(tlsSessionIoRead(session), output), FileReadError,
"timeout after 500ms waiting for read from '%s:%u'", strPtr(harnessTlsTestHost()), harnessTlsTestPort());
// -----------------------------------------------------------------------------------------------------------------
input = BUFSTRDEF("more protocol info");
TEST_RESULT_VOID(tlsClientOpen(client), "open client again (it is already open)");
TEST_RESULT_VOID(ioWrite(tlsClientIoWrite(client), input), "write input");
ioWriteFlush(tlsClientIoWrite(client));
TEST_RESULT_VOID(ioWrite(tlsSessionIoWrite(session), input), "write input");
ioWriteFlush(tlsSessionIoWrite(session));
output = bufNew(12);
TEST_RESULT_UINT(ioRead(tlsClientIoRead(client), output), 12, "read output");
TEST_RESULT_UINT(ioRead(tlsSessionIoRead(session), output), 12, "read output");
TEST_RESULT_STR_Z(strNewBuf(output), "0123456789AB", " check output");
TEST_RESULT_BOOL(ioReadEof(tlsClientIoRead(client)), false, " check eof = false");
TEST_RESULT_BOOL(ioReadEof(tlsSessionIoRead(session)), false, " check eof = false");
output = bufNew(12);
TEST_RESULT_UINT(ioRead(tlsClientIoRead(client), output), 0, "read no output after eof");
TEST_RESULT_BOOL(ioReadEof(tlsClientIoRead(client)), true, " check eof = true");
TEST_RESULT_UINT(ioRead(tlsSessionIoRead(session), output), 0, "read no output after eof");
TEST_RESULT_BOOL(ioReadEof(tlsSessionIoRead(session)), true, " check eof = true");
TEST_RESULT_VOID(tlsSessionClose(session), "close again");
TEST_ERROR(tlsSessionError(session, SSL_ERROR_WANT_X509_LOOKUP), ServiceError, "tls error [4]");
// -----------------------------------------------------------------------------------------------------------------
TEST_RESULT_VOID(tlsClientOpen(client), "open client again (was closed by server)");
TEST_RESULT_BOOL(tlsWriteContinue(client, -1, SSL_ERROR_WANT_READ, 1), true, "continue on WANT_READ");
TEST_RESULT_BOOL(tlsWriteContinue(client, 0, SSL_ERROR_NONE, 1), true, "continue on WANT_READ");
TEST_ASSIGN(session, tlsClientOpen(client), "open client again (was closed by server)");
TEST_RESULT_BOOL(tlsSessionWriteContinue(session, -1, SSL_ERROR_WANT_READ, 1), true, "continue on WANT_READ");
TEST_RESULT_BOOL(tlsSessionWriteContinue(session, 0, SSL_ERROR_NONE, 1), true, "continue on WANT_READ");
TEST_ERROR(
tlsWriteContinue(client, 77, 0, 88), FileWriteError,
tlsSessionWriteContinue(session, 77, 0, 88), FileWriteError,
"unable to write to tls, write size 77 does not match expected size 88");
TEST_ERROR(tlsWriteContinue(client, 0, SSL_ERROR_ZERO_RETURN, 1), FileWriteError, "unable to write to tls [6]");
TEST_ERROR(
tlsSessionWriteContinue(session, 0, SSL_ERROR_ZERO_RETURN, 1), FileWriteError, "unable to write to tls [6]");
// -----------------------------------------------------------------------------------------------------------------
TEST_RESULT_BOOL(sckClientStatStr() != NULL, true, "check statistics exist");