Annotation of embedaddon/strongswan/src/libtls/tls_socket.c, revision 1.1.1.2

1.1       misho       1: /*
                      2:  * Copyright (C) 2010 Martin Willi
                      3:  * Copyright (C) 2010 revosec AG
                      4:  *
                      5:  * This program is free software; you can redistribute it and/or modify it
                      6:  * under the terms of the GNU General Public License as published by the
                      7:  * Free Software Foundation; either version 2 of the License, or (at your
                      8:  * option) any later version.  See <http://www.fsf.org/copyleft/gpl.txt>.
                      9:  *
                     10:  * This program is distributed in the hope that it will be useful, but
                     11:  * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
                     12:  * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
                     13:  * for more details.
                     14:  */
                     15: 
                     16: #include "tls_socket.h"
                     17: 
                     18: #include <unistd.h>
                     19: #include <errno.h>
                     20: 
                     21: #include <utils/debug.h>
                     22: #include <threading/thread.h>
                     23: 
                     24: /**
                     25:  * Buffer size for plain side I/O
                     26:  */
                     27: #define PLAIN_BUF_SIZE TLS_MAX_FRAGMENT_LEN
                     28: 
                     29: /**
                     30:  * Buffer size for encrypted side I/O
                     31:  */
                     32: #define CRYPTO_BUF_SIZE        TLS_MAX_FRAGMENT_LEN + 2048
                     33: 
                     34: typedef struct private_tls_socket_t private_tls_socket_t;
                     35: typedef struct private_tls_application_t private_tls_application_t;
                     36: 
                     37: struct private_tls_application_t {
                     38: 
                     39:        /**
                     40:         * Implements tls_application layer.
                     41:         */
                     42:        tls_application_t application;
                     43: 
                     44:        /**
                     45:         * Output buffer to write to
                     46:         */
                     47:        chunk_t out;
                     48: 
                     49:        /**
                     50:         * Number of bytes written to out
                     51:         */
                     52:        size_t out_done;
                     53: 
                     54:        /**
                     55:         * Input buffer to read to
                     56:         */
                     57:        chunk_t in;
                     58: 
                     59:        /**
                     60:         * Number of bytes read to in
                     61:         */
                     62:        size_t in_done;
                     63: 
                     64:        /**
                     65:         * Cached input data
                     66:         */
                     67:        chunk_t cache;
                     68: 
                     69:        /**
                     70:         * Bytes consumed in cache
                     71:         */
                     72:        size_t cache_done;
                     73: 
                     74:        /**
                     75:         * Close TLS connection?
                     76:         */
                     77:        bool close;
                     78: };
                     79: 
                     80: /**
                     81:  * Private data of an tls_socket_t object.
                     82:  */
                     83: struct private_tls_socket_t {
                     84: 
                     85:        /**
                     86:         * Public tls_socket_t interface.
                     87:         */
                     88:        tls_socket_t public;
                     89: 
                     90:        /**
                     91:         * TLS application implementation
                     92:         */
                     93:        private_tls_application_t app;
                     94: 
                     95:        /**
                     96:         * TLS stack
                     97:         */
                     98:        tls_t *tls;
                     99: 
                    100:        /**
                    101:         * Underlying OS socket
                    102:         */
                    103:        int fd;
1.1.1.2 ! misho     104: 
        !           105:        /**
        !           106:         * Whether the socket returned EOF
        !           107:         */
        !           108:        bool eof;
1.1       misho     109: };
                    110: 
                    111: METHOD(tls_application_t, process, status_t,
                    112:        private_tls_application_t *this, bio_reader_t *reader)
                    113: {
                    114:        chunk_t data;
                    115:        size_t len;
                    116: 
                    117:        if (this->close)
                    118:        {
                    119:                return SUCCESS;
                    120:        }
                    121:        len = min(reader->remaining(reader), this->in.len - this->in_done);
                    122:        if (len)
                    123:        {       /* copy to read buffer as much as fits in */
                    124:                if (!reader->read_data(reader, len, &data))
                    125:                {
                    126:                        return FAILED;
                    127:                }
                    128:                memcpy(this->in.ptr + this->in_done, data.ptr, data.len);
                    129:                this->in_done += data.len;
                    130:        }
                    131:        else
                    132:        {       /* read buffer is full, cache for next read */
                    133:                if (!reader->read_data(reader, reader->remaining(reader), &data))
                    134:                {
                    135:                        return FAILED;
                    136:                }
                    137:                this->cache = chunk_cat("mc", this->cache, data);
                    138:        }
                    139:        return NEED_MORE;
                    140: }
                    141: 
                    142: METHOD(tls_application_t, build, status_t,
                    143:        private_tls_application_t *this, bio_writer_t *writer)
                    144: {
                    145:        if (this->close)
                    146:        {
                    147:                return SUCCESS;
                    148:        }
                    149:        if (this->out.len > this->out_done)
                    150:        {
                    151:                writer->write_data(writer, this->out);
                    152:                this->out_done = this->out.len;
                    153:                return NEED_MORE;
                    154:        }
                    155:        return INVALID_STATE;
                    156: }
                    157: 
                    158: /**
                    159:  * TLS data exchange loop
                    160:  */
                    161: static bool exchange(private_tls_socket_t *this, bool wr, bool block)
                    162: {
                    163:        char buf[CRYPTO_BUF_SIZE], *pos;
                    164:        ssize_t in, out;
                    165:        size_t len;
                    166:        int flags;
                    167: 
                    168:        while (TRUE)
                    169:        {
                    170:                while (TRUE)
                    171:                {
                    172:                        len = sizeof(buf);
                    173:                        switch (this->tls->build(this->tls, buf, &len, NULL))
                    174:                        {
                    175:                                case NEED_MORE:
                    176:                                case ALREADY_DONE:
                    177:                                        pos = buf;
                    178:                                        while (len)
                    179:                                        {
                    180:                                                out = write(this->fd, pos, len);
                    181:                                                if (out == -1)
                    182:                                                {
                    183:                                                        DBG1(DBG_TLS, "TLS crypto write error: %s",
                    184:                                                                 strerror(errno));
                    185:                                                        return FALSE;
                    186:                                                }
                    187:                                                len -= out;
                    188:                                                pos += out;
                    189:                                        }
                    190:                                        continue;
                    191:                                case INVALID_STATE:
                    192:                                        break;
                    193:                                case SUCCESS:
                    194:                                        return TRUE;
                    195:                                default:
1.1.1.2 ! misho     196:                                        if (wr)
        !           197:                                        {
        !           198:                                                return FALSE;
        !           199:                                        }
        !           200:                                        break;
1.1       misho     201:                        }
                    202:                        break;
                    203:                }
                    204:                if (wr)
                    205:                {
                    206:                        if (this->app.out_done == this->app.out.len)
                    207:                        {       /* all data written */
                    208:                                return TRUE;
                    209:                        }
                    210:                }
                    211:                else
                    212:                {
                    213:                        if (this->app.in_done == this->app.in.len)
                    214:                        {       /* buffer fully received */
                    215:                                return TRUE;
                    216:                        }
                    217:                }
                    218: 
                    219:                flags = 0;
                    220:                if (this->app.out_done == this->app.out.len)
                    221:                {
                    222:                        if (!block || this->app.in_done)
                    223:                        {
                    224:                                flags |= MSG_DONTWAIT;
                    225:                        }
                    226:                }
                    227:                in = recv(this->fd, buf, sizeof(buf), flags);
                    228:                if (in < 0)
                    229:                {
                    230:                        if (errno == EAGAIN || errno == EWOULDBLOCK)
                    231:                        {
                    232:                                if (this->app.in_done == 0)
                    233:                                {
                    234:                                        /* reading, nothing got yet, and call would block */
                    235:                                        errno = EWOULDBLOCK;
                    236:                                        this->app.in_done = -1;
                    237:                                }
                    238:                                return TRUE;
                    239:                        }
                    240:                        return FALSE;
                    241:                }
                    242:                if (in == 0)
                    243:                {       /* EOF */
1.1.1.2 ! misho     244:                        this->eof = TRUE;
1.1       misho     245:                        return TRUE;
                    246:                }
                    247:                switch (this->tls->process(this->tls, buf, in))
                    248:                {
                    249:                        case NEED_MORE:
                    250:                                break;
                    251:                        case SUCCESS:
                    252:                                return TRUE;
                    253:                        default:
                    254:                                return FALSE;
                    255:                }
                    256:        }
                    257: }
                    258: 
                    259: METHOD(tls_socket_t, read_, ssize_t,
                    260:        private_tls_socket_t *this, void *buf, size_t len, bool block)
                    261: {
                    262:        if (this->app.cache.len)
                    263:        {
                    264:                size_t cache;
                    265: 
                    266:                cache = min(len, this->app.cache.len - this->app.cache_done);
                    267:                memcpy(buf, this->app.cache.ptr + this->app.cache_done, cache);
                    268: 
                    269:                this->app.cache_done += cache;
                    270:                if (this->app.cache_done == this->app.cache.len)
                    271:                {
                    272:                        chunk_free(&this->app.cache);
                    273:                        this->app.cache_done = 0;
                    274:                }
                    275:                return cache;
                    276:        }
1.1.1.2 ! misho     277:        if (this->eof)
        !           278:        {
        !           279:                return 0;
        !           280:        }
1.1       misho     281:        this->app.in.ptr = buf;
                    282:        this->app.in.len = len;
                    283:        this->app.in_done = 0;
                    284:        if (exchange(this, FALSE, block))
                    285:        {
1.1.1.2 ! misho     286:                if (!this->app.in_done && !this->eof)
        !           287:                {
        !           288:                        errno = EWOULDBLOCK;
        !           289:                        return -1;
        !           290:                }
1.1       misho     291:                return this->app.in_done;
                    292:        }
                    293:        return -1;
                    294: }
                    295: 
                    296: METHOD(tls_socket_t, write_, ssize_t,
                    297:        private_tls_socket_t *this, void *buf, size_t len)
                    298: {
                    299:        this->app.out.ptr = buf;
                    300:        this->app.out.len = len;
                    301:        this->app.out_done = 0;
                    302:        if (exchange(this, TRUE, FALSE))
                    303:        {
                    304:                return this->app.out_done;
                    305:        }
                    306:        return -1;
                    307: }
                    308: 
                    309: METHOD(tls_socket_t, splice, bool,
                    310:        private_tls_socket_t *this, int rfd, int wfd)
                    311: {
                    312:        char buf[PLAIN_BUF_SIZE], *pos;
                    313:        ssize_t in, out;
1.1.1.2 ! misho     314:        bool old, crypto_eof = FALSE;
1.1       misho     315:        struct pollfd pfd[] = {
                    316:                { .fd = this->fd,       .events = POLLIN, },
                    317:                { .fd = rfd,            .events = POLLIN, },
                    318:        };
                    319: 
1.1.1.2 ! misho     320:        while (!this->eof && !crypto_eof)
1.1       misho     321:        {
                    322:                old = thread_cancelability(TRUE);
                    323:                in = poll(pfd, countof(pfd), -1);
                    324:                thread_cancelability(old);
                    325:                if (in == -1)
                    326:                {
                    327:                        DBG1(DBG_TLS, "TLS select error: %s", strerror(errno));
                    328:                        return FALSE;
                    329:                }
1.1.1.2 ! misho     330:                while (!this->eof && pfd[0].revents & (POLLIN | POLLHUP | POLLNVAL))
1.1       misho     331:                {
                    332:                        in = read_(this, buf, sizeof(buf), FALSE);
                    333:                        switch (in)
                    334:                        {
                    335:                                case -1:
                    336:                                        if (errno != EWOULDBLOCK)
                    337:                                        {
                    338:                                                DBG1(DBG_TLS, "TLS read error: %s", strerror(errno));
                    339:                                                return FALSE;
                    340:                                        }
                    341:                                        break;
                    342:                                default:
                    343:                                        pos = buf;
                    344:                                        while (in)
                    345:                                        {
                    346:                                                out = write(wfd, pos, in);
                    347:                                                if (out == -1)
                    348:                                                {
                    349:                                                        DBG1(DBG_TLS, "TLS plain write error: %s",
                    350:                                                                 strerror(errno));
                    351:                                                        return FALSE;
                    352:                                                }
                    353:                                                in -= out;
                    354:                                                pos += out;
                    355:                                        }
                    356:                                        continue;
                    357:                        }
                    358:                        break;
                    359:                }
                    360:                if (!crypto_eof && pfd[1].revents & (POLLIN | POLLHUP | POLLNVAL))
                    361:                {
                    362:                        in = read(rfd, buf, sizeof(buf));
                    363:                        switch (in)
                    364:                        {
                    365:                                case 0:
                    366:                                        crypto_eof = TRUE;
                    367:                                        break;
                    368:                                case -1:
                    369:                                        DBG1(DBG_TLS, "TLS plain read error: %s", strerror(errno));
                    370:                                        return FALSE;
                    371:                                default:
                    372:                                        pos = buf;
                    373:                                        while (in)
                    374:                                        {
                    375:                                                out = write_(this, pos, in);
                    376:                                                if (out == -1)
                    377:                                                {
                    378:                                                        DBG1(DBG_TLS, "TLS write error");
                    379:                                                        return FALSE;
                    380:                                                }
                    381:                                                in -= out;
                    382:                                                pos += out;
                    383:                                        }
                    384:                                        break;
                    385:                        }
                    386:                }
                    387:        }
                    388:        return TRUE;
                    389: }
                    390: 
                    391: METHOD(tls_socket_t, get_fd, int,
                    392:        private_tls_socket_t *this)
                    393: {
                    394:        return this->fd;
                    395: }
                    396: 
                    397: METHOD(tls_socket_t, get_server_id, identification_t*,
                    398:        private_tls_socket_t *this)
                    399: {
                    400:        return this->tls->get_server_id(this->tls);
                    401: }
                    402: 
                    403: METHOD(tls_socket_t, get_peer_id, identification_t*,
                    404:        private_tls_socket_t *this)
                    405: {
                    406:        return this->tls->get_peer_id(this->tls);
                    407: }
                    408: 
                    409: METHOD(tls_socket_t, destroy, void,
                    410:        private_tls_socket_t *this)
                    411: {
                    412:        /* send a TLS close notify if not done yet */
                    413:        this->app.close = TRUE;
                    414:        write_(this, NULL, 0);
                    415:        free(this->app.cache.ptr);
                    416:        this->tls->destroy(this->tls);
                    417:        free(this);
                    418: }
                    419: 
                    420: /**
                    421:  * See header
                    422:  */
                    423: tls_socket_t *tls_socket_create(bool is_server, identification_t *server,
1.1.1.2 ! misho     424:                                                                identification_t *peer, int fd,
        !           425:                                                                tls_cache_t *cache, tls_version_t min_version,
        !           426:                                                                tls_version_t max_version, tls_flag_t flags)
1.1       misho     427: {
                    428:        private_tls_socket_t *this;
                    429: 
                    430:        INIT(this,
                    431:                .public = {
                    432:                        .read = _read_,
                    433:                        .write = _write_,
                    434:                        .splice = _splice,
                    435:                        .get_fd = _get_fd,
                    436:                        .get_server_id = _get_server_id,
                    437:                        .get_peer_id = _get_peer_id,
                    438:                        .destroy = _destroy,
                    439:                },
                    440:                .app = {
                    441:                        .application = {
                    442:                                .build = _build,
                    443:                                .process = _process,
                    444:                                .destroy = (void*)nop,
                    445:                        },
                    446:                },
                    447:                .fd = fd,
                    448:        );
                    449: 
1.1.1.2 ! misho     450:        this->tls = tls_create(is_server, server, peer, TLS_PURPOSE_GENERIC,
        !           451:                                                   &this->app.application, cache, flags);
        !           452:        if (!this->tls ||
        !           453:                !this->tls->set_version(this->tls, min_version, max_version))
1.1       misho     454:        {
                    455:                free(this);
                    456:                return NULL;
                    457:        }
                    458:        return &this->public;
                    459: }

FreeBSD-CVSweb <freebsd-cvsweb@FreeBSD.org>