Annotation of embedaddon/strongswan/src/libtls/tls_socket.c, revision 1.1.1.1

1.1       misho       1: /*
                      2:  * Copyright (C) 2010 Martin Willi
                      3:  * Copyright (C) 2010 revosec AG
                      4:  *
                      5:  * This program is free software; you can redistribute it and/or modify it
                      6:  * under the terms of the GNU General Public License as published by the
                      7:  * Free Software Foundation; either version 2 of the License, or (at your
                      8:  * option) any later version.  See <http://www.fsf.org/copyleft/gpl.txt>.
                      9:  *
                     10:  * This program is distributed in the hope that it will be useful, but
                     11:  * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
                     12:  * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
                     13:  * for more details.
                     14:  */
                     15: 
                     16: #include "tls_socket.h"
                     17: 
                     18: #include <unistd.h>
                     19: #include <errno.h>
                     20: 
                     21: #include <utils/debug.h>
                     22: #include <threading/thread.h>
                     23: 
                     24: /**
                     25:  * Buffer size for plain side I/O
                     26:  */
                     27: #define PLAIN_BUF_SIZE TLS_MAX_FRAGMENT_LEN
                     28: 
                     29: /**
                     30:  * Buffer size for encrypted side I/O
                     31:  */
                     32: #define CRYPTO_BUF_SIZE        TLS_MAX_FRAGMENT_LEN + 2048
                     33: 
                     34: typedef struct private_tls_socket_t private_tls_socket_t;
                     35: typedef struct private_tls_application_t private_tls_application_t;
                     36: 
                     37: struct private_tls_application_t {
                     38: 
                     39:        /**
                     40:         * Implements tls_application layer.
                     41:         */
                     42:        tls_application_t application;
                     43: 
                     44:        /**
                     45:         * Output buffer to write to
                     46:         */
                     47:        chunk_t out;
                     48: 
                     49:        /**
                     50:         * Number of bytes written to out
                     51:         */
                     52:        size_t out_done;
                     53: 
                     54:        /**
                     55:         * Input buffer to read to
                     56:         */
                     57:        chunk_t in;
                     58: 
                     59:        /**
                     60:         * Number of bytes read to in
                     61:         */
                     62:        size_t in_done;
                     63: 
                     64:        /**
                     65:         * Cached input data
                     66:         */
                     67:        chunk_t cache;
                     68: 
                     69:        /**
                     70:         * Bytes consumed in cache
                     71:         */
                     72:        size_t cache_done;
                     73: 
                     74:        /**
                     75:         * Close TLS connection?
                     76:         */
                     77:        bool close;
                     78: };
                     79: 
                     80: /**
                     81:  * Private data of an tls_socket_t object.
                     82:  */
                     83: struct private_tls_socket_t {
                     84: 
                     85:        /**
                     86:         * Public tls_socket_t interface.
                     87:         */
                     88:        tls_socket_t public;
                     89: 
                     90:        /**
                     91:         * TLS application implementation
                     92:         */
                     93:        private_tls_application_t app;
                     94: 
                     95:        /**
                     96:         * TLS stack
                     97:         */
                     98:        tls_t *tls;
                     99: 
                    100:        /**
                    101:         * Underlying OS socket
                    102:         */
                    103:        int fd;
                    104: };
                    105: 
                    106: METHOD(tls_application_t, process, status_t,
                    107:        private_tls_application_t *this, bio_reader_t *reader)
                    108: {
                    109:        chunk_t data;
                    110:        size_t len;
                    111: 
                    112:        if (this->close)
                    113:        {
                    114:                return SUCCESS;
                    115:        }
                    116:        len = min(reader->remaining(reader), this->in.len - this->in_done);
                    117:        if (len)
                    118:        {       /* copy to read buffer as much as fits in */
                    119:                if (!reader->read_data(reader, len, &data))
                    120:                {
                    121:                        return FAILED;
                    122:                }
                    123:                memcpy(this->in.ptr + this->in_done, data.ptr, data.len);
                    124:                this->in_done += data.len;
                    125:        }
                    126:        else
                    127:        {       /* read buffer is full, cache for next read */
                    128:                if (!reader->read_data(reader, reader->remaining(reader), &data))
                    129:                {
                    130:                        return FAILED;
                    131:                }
                    132:                this->cache = chunk_cat("mc", this->cache, data);
                    133:        }
                    134:        return NEED_MORE;
                    135: }
                    136: 
                    137: METHOD(tls_application_t, build, status_t,
                    138:        private_tls_application_t *this, bio_writer_t *writer)
                    139: {
                    140:        if (this->close)
                    141:        {
                    142:                return SUCCESS;
                    143:        }
                    144:        if (this->out.len > this->out_done)
                    145:        {
                    146:                writer->write_data(writer, this->out);
                    147:                this->out_done = this->out.len;
                    148:                return NEED_MORE;
                    149:        }
                    150:        return INVALID_STATE;
                    151: }
                    152: 
                    153: /**
                    154:  * TLS data exchange loop
                    155:  */
                    156: static bool exchange(private_tls_socket_t *this, bool wr, bool block)
                    157: {
                    158:        char buf[CRYPTO_BUF_SIZE], *pos;
                    159:        ssize_t in, out;
                    160:        size_t len;
                    161:        int flags;
                    162: 
                    163:        while (TRUE)
                    164:        {
                    165:                while (TRUE)
                    166:                {
                    167:                        len = sizeof(buf);
                    168:                        switch (this->tls->build(this->tls, buf, &len, NULL))
                    169:                        {
                    170:                                case NEED_MORE:
                    171:                                case ALREADY_DONE:
                    172:                                        pos = buf;
                    173:                                        while (len)
                    174:                                        {
                    175:                                                out = write(this->fd, pos, len);
                    176:                                                if (out == -1)
                    177:                                                {
                    178:                                                        DBG1(DBG_TLS, "TLS crypto write error: %s",
                    179:                                                                 strerror(errno));
                    180:                                                        return FALSE;
                    181:                                                }
                    182:                                                len -= out;
                    183:                                                pos += out;
                    184:                                        }
                    185:                                        continue;
                    186:                                case INVALID_STATE:
                    187:                                        break;
                    188:                                case SUCCESS:
                    189:                                        return TRUE;
                    190:                                default:
                    191:                                        return FALSE;
                    192:                        }
                    193:                        break;
                    194:                }
                    195:                if (wr)
                    196:                {
                    197:                        if (this->app.out_done == this->app.out.len)
                    198:                        {       /* all data written */
                    199:                                return TRUE;
                    200:                        }
                    201:                }
                    202:                else
                    203:                {
                    204:                        if (this->app.in_done == this->app.in.len)
                    205:                        {       /* buffer fully received */
                    206:                                return TRUE;
                    207:                        }
                    208:                }
                    209: 
                    210:                flags = 0;
                    211:                if (this->app.out_done == this->app.out.len)
                    212:                {
                    213:                        if (!block || this->app.in_done)
                    214:                        {
                    215:                                flags |= MSG_DONTWAIT;
                    216:                        }
                    217:                }
                    218:                in = recv(this->fd, buf, sizeof(buf), flags);
                    219:                if (in < 0)
                    220:                {
                    221:                        if (errno == EAGAIN || errno == EWOULDBLOCK)
                    222:                        {
                    223:                                if (this->app.in_done == 0)
                    224:                                {
                    225:                                        /* reading, nothing got yet, and call would block */
                    226:                                        errno = EWOULDBLOCK;
                    227:                                        this->app.in_done = -1;
                    228:                                }
                    229:                                return TRUE;
                    230:                        }
                    231:                        return FALSE;
                    232:                }
                    233:                if (in == 0)
                    234:                {       /* EOF */
                    235:                        return TRUE;
                    236:                }
                    237:                switch (this->tls->process(this->tls, buf, in))
                    238:                {
                    239:                        case NEED_MORE:
                    240:                                break;
                    241:                        case SUCCESS:
                    242:                                return TRUE;
                    243:                        default:
                    244:                                return FALSE;
                    245:                }
                    246:        }
                    247: }
                    248: 
                    249: METHOD(tls_socket_t, read_, ssize_t,
                    250:        private_tls_socket_t *this, void *buf, size_t len, bool block)
                    251: {
                    252:        if (this->app.cache.len)
                    253:        {
                    254:                size_t cache;
                    255: 
                    256:                cache = min(len, this->app.cache.len - this->app.cache_done);
                    257:                memcpy(buf, this->app.cache.ptr + this->app.cache_done, cache);
                    258: 
                    259:                this->app.cache_done += cache;
                    260:                if (this->app.cache_done == this->app.cache.len)
                    261:                {
                    262:                        chunk_free(&this->app.cache);
                    263:                        this->app.cache_done = 0;
                    264:                }
                    265:                return cache;
                    266:        }
                    267:        this->app.in.ptr = buf;
                    268:        this->app.in.len = len;
                    269:        this->app.in_done = 0;
                    270:        if (exchange(this, FALSE, block))
                    271:        {
                    272:                return this->app.in_done;
                    273:        }
                    274:        return -1;
                    275: }
                    276: 
                    277: METHOD(tls_socket_t, write_, ssize_t,
                    278:        private_tls_socket_t *this, void *buf, size_t len)
                    279: {
                    280:        this->app.out.ptr = buf;
                    281:        this->app.out.len = len;
                    282:        this->app.out_done = 0;
                    283:        if (exchange(this, TRUE, FALSE))
                    284:        {
                    285:                return this->app.out_done;
                    286:        }
                    287:        return -1;
                    288: }
                    289: 
                    290: METHOD(tls_socket_t, splice, bool,
                    291:        private_tls_socket_t *this, int rfd, int wfd)
                    292: {
                    293:        char buf[PLAIN_BUF_SIZE], *pos;
                    294:        ssize_t in, out;
                    295:        bool old, plain_eof = FALSE, crypto_eof = FALSE;
                    296:        struct pollfd pfd[] = {
                    297:                { .fd = this->fd,       .events = POLLIN, },
                    298:                { .fd = rfd,            .events = POLLIN, },
                    299:        };
                    300: 
                    301:        while (!plain_eof && !crypto_eof)
                    302:        {
                    303:                old = thread_cancelability(TRUE);
                    304:                in = poll(pfd, countof(pfd), -1);
                    305:                thread_cancelability(old);
                    306:                if (in == -1)
                    307:                {
                    308:                        DBG1(DBG_TLS, "TLS select error: %s", strerror(errno));
                    309:                        return FALSE;
                    310:                }
                    311:                while (!plain_eof && pfd[0].revents & (POLLIN | POLLHUP | POLLNVAL))
                    312:                {
                    313:                        in = read_(this, buf, sizeof(buf), FALSE);
                    314:                        switch (in)
                    315:                        {
                    316:                                case 0:
                    317:                                        plain_eof = TRUE;
                    318:                                        break;
                    319:                                case -1:
                    320:                                        if (errno != EWOULDBLOCK)
                    321:                                        {
                    322:                                                DBG1(DBG_TLS, "TLS read error: %s", strerror(errno));
                    323:                                                return FALSE;
                    324:                                        }
                    325:                                        break;
                    326:                                default:
                    327:                                        pos = buf;
                    328:                                        while (in)
                    329:                                        {
                    330:                                                out = write(wfd, pos, in);
                    331:                                                if (out == -1)
                    332:                                                {
                    333:                                                        DBG1(DBG_TLS, "TLS plain write error: %s",
                    334:                                                                 strerror(errno));
                    335:                                                        return FALSE;
                    336:                                                }
                    337:                                                in -= out;
                    338:                                                pos += out;
                    339:                                        }
                    340:                                        continue;
                    341:                        }
                    342:                        break;
                    343:                }
                    344:                if (!crypto_eof && pfd[1].revents & (POLLIN | POLLHUP | POLLNVAL))
                    345:                {
                    346:                        in = read(rfd, buf, sizeof(buf));
                    347:                        switch (in)
                    348:                        {
                    349:                                case 0:
                    350:                                        crypto_eof = TRUE;
                    351:                                        break;
                    352:                                case -1:
                    353:                                        DBG1(DBG_TLS, "TLS plain read error: %s", strerror(errno));
                    354:                                        return FALSE;
                    355:                                default:
                    356:                                        pos = buf;
                    357:                                        while (in)
                    358:                                        {
                    359:                                                out = write_(this, pos, in);
                    360:                                                if (out == -1)
                    361:                                                {
                    362:                                                        DBG1(DBG_TLS, "TLS write error");
                    363:                                                        return FALSE;
                    364:                                                }
                    365:                                                in -= out;
                    366:                                                pos += out;
                    367:                                        }
                    368:                                        break;
                    369:                        }
                    370:                }
                    371:        }
                    372:        return TRUE;
                    373: }
                    374: 
                    375: METHOD(tls_socket_t, get_fd, int,
                    376:        private_tls_socket_t *this)
                    377: {
                    378:        return this->fd;
                    379: }
                    380: 
                    381: METHOD(tls_socket_t, get_server_id, identification_t*,
                    382:        private_tls_socket_t *this)
                    383: {
                    384:        return this->tls->get_server_id(this->tls);
                    385: }
                    386: 
                    387: METHOD(tls_socket_t, get_peer_id, identification_t*,
                    388:        private_tls_socket_t *this)
                    389: {
                    390:        return this->tls->get_peer_id(this->tls);
                    391: }
                    392: 
                    393: METHOD(tls_socket_t, destroy, void,
                    394:        private_tls_socket_t *this)
                    395: {
                    396:        /* send a TLS close notify if not done yet */
                    397:        this->app.close = TRUE;
                    398:        write_(this, NULL, 0);
                    399:        free(this->app.cache.ptr);
                    400:        this->tls->destroy(this->tls);
                    401:        free(this);
                    402: }
                    403: 
                    404: /**
                    405:  * See header
                    406:  */
                    407: tls_socket_t *tls_socket_create(bool is_server, identification_t *server,
                    408:                                                        identification_t *peer, int fd, tls_cache_t *cache,
                    409:                                                        tls_version_t max_version, bool nullok)
                    410: {
                    411:        private_tls_socket_t *this;
                    412:        tls_purpose_t purpose;
                    413: 
                    414:        INIT(this,
                    415:                .public = {
                    416:                        .read = _read_,
                    417:                        .write = _write_,
                    418:                        .splice = _splice,
                    419:                        .get_fd = _get_fd,
                    420:                        .get_server_id = _get_server_id,
                    421:                        .get_peer_id = _get_peer_id,
                    422:                        .destroy = _destroy,
                    423:                },
                    424:                .app = {
                    425:                        .application = {
                    426:                                .build = _build,
                    427:                                .process = _process,
                    428:                                .destroy = (void*)nop,
                    429:                        },
                    430:                },
                    431:                .fd = fd,
                    432:        );
                    433: 
                    434:        if (nullok)
                    435:        {
                    436:                purpose = TLS_PURPOSE_GENERIC_NULLOK;
                    437:        }
                    438:        else
                    439:        {
                    440:                purpose = TLS_PURPOSE_GENERIC;
                    441:        }
                    442: 
                    443:        this->tls = tls_create(is_server, server, peer, purpose,
                    444:                                                   &this->app.application, cache);
                    445:        if (!this->tls)
                    446:        {
                    447:                free(this);
                    448:                return NULL;
                    449:        }
                    450:        this->tls->set_version(this->tls, max_version);
                    451: 
                    452:        return &this->public;
                    453: }

FreeBSD-CVSweb <freebsd-cvsweb@FreeBSD.org>