File:  [ELWIX - Embedded LightWeight unIX -] / embedaddon / strongswan / src / libtls / tls_socket.c
Revision 1.1.1.2 (vendor branch): download - view: text, annotated - select for diffs - revision graph
Wed Mar 17 00:20:09 2021 UTC (3 years, 7 months ago) by misho
Branches: strongswan, MAIN
CVS tags: v5_9_2p0, HEAD
strongswan 5.9.2

/*
 * Copyright (C) 2010 Martin Willi
 * Copyright (C) 2010 revosec AG
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the
 * Free Software Foundation; either version 2 of the License, or (at your
 * option) any later version.  See <http://www.fsf.org/copyleft/gpl.txt>.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * for more details.
 */

#include "tls_socket.h"

#include <unistd.h>
#include <errno.h>

#include <utils/debug.h>
#include <threading/thread.h>

/**
 * Buffer size for plain side I/O.
 *
 * Sizes the stack buffer splice() uses to move plaintext between the plain
 * file descriptors and the TLS socket. TLS_MAX_FRAGMENT_LEN is defined
 * outside this file.
 */
#define PLAIN_BUF_SIZE	TLS_MAX_FRAGMENT_LEN

/**
 * Buffer size for encrypted side I/O.
 *
 * One maximum-size fragment plus 2048 bytes of headroom; sizes the stack
 * buffer exchange() uses for records travelling to and from the socket.
 * The expansion is parenthesized so the macro acts as a single value when it
 * is embedded in a larger expression (e.g. 2 * CRYPTO_BUF_SIZE).
 */
#define CRYPTO_BUF_SIZE	(TLS_MAX_FRAGMENT_LEN + 2048)

/* Short names for the two private structs defined further down in this file. */
typedef struct private_tls_socket_t private_tls_socket_t;
typedef struct private_tls_application_t private_tls_application_t;

/**
 * Application-layer state handed to the TLS stack.
 *
 * Bundles the tls_application_t callbacks registered with tls_create() and
 * the buffers describing the read_()/write_() call currently in progress.
 */
struct private_tls_application_t {

	/**
	 * Implements tls_application layer.
	 *
	 * NOTE(review): presumably has to stay the first member so the METHOD()
	 * wrappers can map the interface pointer back to this struct; confirm
	 * against the METHOD macro.
	 */
	tls_application_t application;

	/**
	 * Plaintext to send: points at the caller's buffer passed to write_(),
	 * it is not owned or freed by this object
	 */
	chunk_t out;

	/**
	 * Number of bytes of out already passed to the writer; build() sets it
	 * to out.len after handing over the whole buffer
	 */
	size_t out_done;

	/**
	 * Destination for received plaintext: points at the caller's buffer
	 * passed to read_(), it is not owned or freed by this object
	 */
	chunk_t in;

	/**
	 * Number of bytes copied into in so far. exchange() stores -1 here
	 * (wrapping to the largest size_t) when a non-blocking receive had
	 * nothing to deliver; read_() then returns that value as -1.
	 */
	size_t in_done;

	/**
	 * Received plaintext that arrived while in had no room left; allocated
	 * by process(), released by read_() once fully consumed or by destroy()
	 */
	chunk_t cache;

	/**
	 * Bytes of cache already returned to the caller by read_()
	 */
	size_t cache_done;

	/**
	 * Close TLS connection? Set by destroy(); makes process() and build()
	 * return SUCCESS instead of transferring data.
	 */
	bool close;
};

/**
 * Private data of a tls_socket_t object.
 */
struct private_tls_socket_t {

	/**
	 * Public tls_socket_t interface.
	 *
	 * tls_socket_create() returns a pointer to this member.
	 * NOTE(review): presumably has to stay first so the METHOD() wrappers can
	 * recover the private struct from it; confirm against the METHOD macro.
	 */
	tls_socket_t public;

	/**
	 * TLS application implementation, registered with the TLS stack in
	 * tls_socket_create()
	 */
	private_tls_application_t app;

	/**
	 * TLS stack, created in tls_socket_create() and destroyed in destroy()
	 */
	tls_t *tls;

	/**
	 * Underlying OS socket, supplied by the caller; destroy() in this file
	 * does not close it
	 */
	int fd;

	/**
	 * Whether the socket returned EOF (recv() returned 0 in exchange())
	 */
	bool eof;
};

/**
 * tls_application_t.process: accept decrypted application data.
 *
 * Copies as much of the reader's content as fits into the pending read_()
 * buffer. If nothing can be copied (the buffer has no room left, or the
 * reader is empty), whatever the reader still holds is appended to the cache
 * for a later read_().
 *
 * @return		SUCCESS once closing, FAILED if the reader cannot deliver the
 *				requested bytes, NEED_MORE otherwise
 */
METHOD(tls_application_t, process, status_t,
	private_tls_application_t *this, bio_reader_t *reader)
{
	chunk_t chunk;
	size_t room, take;

	if (this->close)
	{
		return SUCCESS;
	}
	room = this->in.len - this->in_done;
	take = min(reader->remaining(reader), room);
	if (!take)
	{	/* nothing fits into the read buffer: stash it for the next read */
		if (!reader->read_data(reader, reader->remaining(reader), &chunk))
		{
			return FAILED;
		}
		this->cache = chunk_cat("mc", this->cache, chunk);
		return NEED_MORE;
	}
	/* fill the caller's buffer behind what has been copied already */
	if (!reader->read_data(reader, take, &chunk))
	{
		return FAILED;
	}
	memcpy(this->in.ptr + this->in_done, chunk.ptr, chunk.len);
	this->in_done += chunk.len;
	return NEED_MORE;
}

/**
 * tls_application_t.build: supply application data to send.
 *
 * Hands the complete pending write_() buffer to the writer in one go and
 * marks it as done.
 *
 * @return		SUCCESS once closing, INVALID_STATE if no unsent data is
 *				pending, NEED_MORE after data has been handed over
 */
METHOD(tls_application_t, build, status_t,
	private_tls_application_t *this, bio_writer_t *writer)
{
	if (this->close)
	{
		return SUCCESS;
	}
	if (this->out_done >= this->out.len)
	{	/* nothing left to hand over */
		return INVALID_STATE;
	}
	writer->write_data(writer, this->out);
	this->out_done = this->out.len;
	return NEED_MORE;
}

/**
 * TLS data exchange loop.
 *
 * Repeatedly (1) writes everything the TLS stack produces to the socket and
 * (2) receives from the socket and feeds the bytes to the TLS stack, until
 * the transfer described by this->app is complete or cannot make progress.
 *
 * @param this		socket; app.in/app.out describe the pending read or write
 * @param wr		TRUE when serving a write, FALSE when serving a read; picks
 *					the completion test and makes an unexpected build() status
 *					fatal
 * @param block		only relevant once all output is handed over: TRUE lets
 *					recv() block until the first input bytes have arrived,
 *					FALSE never blocks
 * @return			FALSE on a socket write error, a recv() error other than
 *					EAGAIN/EWOULDBLOCK, or a failing TLS stack status; TRUE in
 *					all other cases (including EOF and would-block)
 */
static bool exchange(private_tls_socket_t *this, bool wr, bool block)
{
	char buf[CRYPTO_BUF_SIZE], *pos;
	ssize_t in, out;
	size_t len;
	int flags;

	while (TRUE)
	{
		/* output phase: send records until the stack stops producing them */
		while (TRUE)
		{
			/* len is passed by address and afterwards used as the number of
			 * bytes to send */
			len = sizeof(buf);
			switch (this->tls->build(this->tls, buf, &len, NULL))
			{
				case NEED_MORE:
				case ALREADY_DONE:
					/* push all produced bytes out, coping with short writes */
					pos = buf;
					while (len)
					{
						out = write(this->fd, pos, len);
						if (out == -1)
						{
							DBG1(DBG_TLS, "TLS crypto write error: %s",
								 strerror(errno));
							return FALSE;
						}
						len -= out;
						pos += out;
					}
					/* ask the stack for further output */
					continue;
				case INVALID_STATE:
					/* leave the output phase and check/receive below */
					break;
				case SUCCESS:
					return TRUE;
				default:
					/* any other status aborts a write, but a read still
					 * proceeds to the receive phase */
					if (wr)
					{
						return FALSE;
					}
					break;
			}
			break;
		}
		if (wr)
		{
			if (this->app.out_done == this->app.out.len)
			{	/* all data written */
				return TRUE;
			}
		}
		else
		{
			if (this->app.in_done == this->app.in.len)
			{	/* buffer fully received */
				return TRUE;
			}
		}

		/* input phase: while output is still pending the receive always
		 * blocks; otherwise it is non-blocking if the caller asked for that
		 * or if some input has been delivered already */
		flags = 0;
		if (this->app.out_done == this->app.out.len)
		{
			if (!block || this->app.in_done)
			{
				flags |= MSG_DONTWAIT;
			}
		}
		in = recv(this->fd, buf, sizeof(buf), flags);
		if (in < 0)
		{
			if (errno == EAGAIN || errno == EWOULDBLOCK)
			{
				if (this->app.in_done == 0)
				{
					/* reading, nothing got yet, and call would block;
					 * in_done is a size_t, so -1 wraps to its maximum value,
					 * which read_() hands back to its caller as -1 */
					errno = EWOULDBLOCK;
					this->app.in_done = -1;
				}
				return TRUE;
			}
			return FALSE;
		}
		if (in == 0)
		{	/* EOF */
			this->eof = TRUE;
			return TRUE;
		}
		/* feed the received bytes to the TLS stack; NEED_MORE starts the next
		 * round with another output phase */
		switch (this->tls->process(this->tls, buf, in))
		{
			case NEED_MORE:
				break;
			case SUCCESS:
				return TRUE;
			default:
				return FALSE;
		}
	}
}

/**
 * tls_socket_t.read: read decrypted application data.
 *
 * Cached data from an earlier exchange is served first, without touching the
 * socket. Otherwise the caller's buffer is registered as the input buffer
 * and the exchange loop is run.
 *
 * @param buf		buffer receiving the plaintext
 * @param len		capacity of buf
 * @param block		passed through to exchange() to select blocking behavior
 * @return			number of bytes stored in buf, 0 on EOF, -1 on error
 *					(errno is EWOULDBLOCK if no data was available)
 */
METHOD(tls_socket_t, read_, ssize_t,
	private_tls_socket_t *this, void *buf, size_t len, bool block)
{
	private_tls_application_t *app = &this->app;

	if (app->cache.len)
	{
		size_t pending, count;

		pending = app->cache.len - app->cache_done;
		count = len < pending ? len : pending;
		memcpy(buf, app->cache.ptr + app->cache_done, count);
		app->cache_done += count;

		if (app->cache_done == app->cache.len)
		{	/* cache fully consumed, drop it */
			chunk_free(&app->cache);
			app->cache_done = 0;
		}
		return count;
	}
	if (this->eof)
	{
		return 0;
	}
	app->in.ptr = buf;
	app->in.len = len;
	app->in_done = 0;
	if (!exchange(this, FALSE, block))
	{
		return -1;
	}
	if (app->in_done == 0 && !this->eof)
	{	/* exchange ended without any data and without EOF */
		errno = EWOULDBLOCK;
		return -1;
	}
	return app->in_done;
}

/**
 * tls_socket_t.write: send application data over the TLS connection.
 *
 * Registers the caller's buffer as pending output and runs the exchange loop
 * in write mode.
 *
 * @param buf		plaintext to send
 * @param len		number of bytes in buf
 * @return			number of bytes handed to the TLS stack, -1 if the
 *					exchange failed
 */
METHOD(tls_socket_t, write_, ssize_t,
	private_tls_socket_t *this, void *buf, size_t len)
{
	private_tls_application_t *app = &this->app;

	app->out.ptr = buf;
	app->out.len = len;
	app->out_done = 0;
	if (!exchange(this, TRUE, FALSE))
	{
		return -1;
	}
	return app->out_done;
}

/**
 * tls_socket_t.splice: relay data between the TLS socket and plain fds.
 *
 * Waits in poll() on the TLS socket and on rfd. Plaintext read from the TLS
 * connection is written to wfd; data read from rfd is sent over the TLS
 * connection. The loop ends when either the TLS socket or rfd reports EOF.
 *
 * @param rfd		plain descriptor to read outgoing data from
 * @param wfd		plain descriptor to write incoming data to
 * @return			TRUE after EOF on either side, FALSE on any poll, read or
 *					write error
 */
METHOD(tls_socket_t, splice, bool,
	private_tls_socket_t *this, int rfd, int wfd)
{
	const int ready = POLLIN | POLLHUP | POLLNVAL;
	char buf[PLAIN_BUF_SIZE], *cursor;
	ssize_t got, put;
	bool cancelable, plain_eof = FALSE;
	struct pollfd pfd[] = {
		{ .fd = this->fd,	.events = POLLIN, },
		{ .fd = rfd,		.events = POLLIN, },
	};

	while (!this->eof && !plain_eof)
	{
		/* only the wait itself may be cancelled */
		cancelable = thread_cancelability(TRUE);
		got = poll(pfd, countof(pfd), -1);
		thread_cancelability(cancelable);
		if (got == -1)
		{
			DBG1(DBG_TLS, "TLS select error: %s", strerror(errno));
			return FALSE;
		}

		/* TLS side: drain everything that can be read without blocking */
		while (!this->eof && (pfd[0].revents & ready))
		{
			got = read_(this, buf, sizeof(buf), FALSE);
			if (got == -1)
			{
				if (errno != EWOULDBLOCK)
				{
					DBG1(DBG_TLS, "TLS read error: %s", strerror(errno));
					return FALSE;
				}
				/* nothing more available right now */
				break;
			}
			for (cursor = buf; got; cursor += put, got -= put)
			{
				put = write(wfd, cursor, got);
				if (put == -1)
				{
					DBG1(DBG_TLS, "TLS plain write error: %s",
						 strerror(errno));
					return FALSE;
				}
			}
		}

		/* plain side: forward a single chunk per poll round */
		if (plain_eof || !(pfd[1].revents & ready))
		{
			continue;
		}
		got = read(rfd, buf, sizeof(buf));
		if (got == 0)
		{
			plain_eof = TRUE;
			continue;
		}
		if (got == -1)
		{
			DBG1(DBG_TLS, "TLS plain read error: %s", strerror(errno));
			return FALSE;
		}
		for (cursor = buf; got; cursor += put, got -= put)
		{
			put = write_(this, cursor, got);
			if (put == -1)
			{
				DBG1(DBG_TLS, "TLS write error");
				return FALSE;
			}
		}
	}
	return TRUE;
}

/**
 * tls_socket_t.get_fd: expose the underlying socket descriptor.
 *
 * @return		the descriptor given to tls_socket_create()
 */
METHOD(tls_socket_t, get_fd, int,
	private_tls_socket_t *this)
{
	int fd = this->fd;

	return fd;
}

/**
 * tls_socket_t.get_server_id: forwards to the TLS stack's get_server_id().
 *
 * @return		whatever the TLS stack reports as server identity
 */
METHOD(tls_socket_t, get_server_id, identification_t*,
	private_tls_socket_t *this)
{
	tls_t *tls = this->tls;

	return tls->get_server_id(tls);
}

/**
 * tls_socket_t.get_peer_id: forwards to the TLS stack's get_peer_id().
 *
 * @return		whatever the TLS stack reports as peer identity
 */
METHOD(tls_socket_t, get_peer_id, identification_t*,
	private_tls_socket_t *this)
{
	tls_t *tls = this->tls;

	return tls->get_peer_id(tls);
}

/**
 * tls_socket_t.destroy: shut down the TLS connection and release the object.
 *
 * Raises the close flag and runs one empty write so the TLS stack gets a
 * last chance to emit output (the close notify, if not sent yet), then frees
 * the cache, the TLS stack and the object. The result of that final write is
 * ignored and the file descriptor is left open.
 */
METHOD(tls_socket_t, destroy, void,
	private_tls_socket_t *this)
{
	tls_t *tls = this->tls;

	/* with close set, build()/process() report SUCCESS instead of moving data */
	this->app.close = TRUE;
	write_(this, NULL, 0);
	chunk_free(&this->app.cache);
	tls->destroy(tls);
	free(this);
}

/**
 * See header.
 *
 * Allocates the socket object, wires up its public methods and application
 * callbacks, creates the TLS stack and restricts it to the requested version
 * range.
 *
 * @param is_server		passed to tls_create()
 * @param server		passed to tls_create()
 * @param peer			passed to tls_create()
 * @param fd			socket descriptor used for all encrypted I/O
 * @param cache			passed to tls_create()
 * @param min_version	lower bound passed to set_version()
 * @param max_version	upper bound passed to set_version()
 * @param flags			passed to tls_create()
 * @return				the public interface, or NULL if the TLS stack could
 *						not be created or rejected the version range
 */
tls_socket_t *tls_socket_create(bool is_server, identification_t *server,
								identification_t *peer, int fd,
								tls_cache_t *cache, tls_version_t min_version,
								tls_version_t max_version, tls_flag_t flags)
{
	private_tls_socket_t *this;

	INIT(this,
		.public = {
			.read = _read_,
			.write = _write_,
			.splice = _splice,
			.get_fd = _get_fd,
			.get_server_id = _get_server_id,
			.get_peer_id = _get_peer_id,
			.destroy = _destroy,
		},
		.app = {
			.application = {
				.build = _build,
				.process = _process,
				.destroy = (void*)nop,
			},
		},
		.fd = fd,
	);

	this->tls = tls_create(is_server, server, peer, TLS_PURPOSE_GENERIC,
						   &this->app.application, cache, flags);
	if (!this->tls)
	{
		free(this);
		return NULL;
	}
	if (!this->tls->set_version(this->tls, min_version, max_version))
	{	/* the stack exists already: destroy it instead of leaking it */
		this->tls->destroy(this->tls);
		free(this);
		return NULL;
	}
	return &this->public;
}

FreeBSD-CVSweb <freebsd-cvsweb@FreeBSD.org>