Annotation of embedaddon/strongswan/src/libcharon/plugins/vici/vici_socket.c, revision 1.1.1.1

1.1       misho       1: /*
                      2:  * Copyright (C) 2014 Martin Willi
                      3:  * Copyright (C) 2014 revosec AG
                      4:  *
                      5:  * This program is free software; you can redistribute it and/or modify it
                      6:  * under the terms of the GNU General Public License as published by the
                      7:  * Free Software Foundation; either version 2 of the License, or (at your
                      8:  * option) any later version.  See <http://www.fsf.org/copyleft/gpl.txt>.
                      9:  *
                     10:  * This program is distributed in the hope that it will be useful, but
                     11:  * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
                     12:  * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
                     13:  * for more details.
                     14:  */
                     15: 
                     16: #include "vici_socket.h"
                     17: 
                     18: #include <threading/mutex.h>
                     19: #include <threading/condvar.h>
                     20: #include <threading/thread.h>
                     21: #include <collections/array.h>
                     22: #include <collections/linked_list.h>
                     23: #include <processing/jobs/callback_job.h>
                     24: 
                     25: #include <errno.h>
                     26: #include <string.h>
                     27: 
                     28: typedef struct private_vici_socket_t private_vici_socket_t;
                     29: 
                     30: /**
                     31:  * Private members of vici_socket_t
                     32:  */
                     33: struct private_vici_socket_t {
                     34: 
                     35:        /**
                     36:         * public functions
                     37:         */
                     38:        vici_socket_t public;
                     39: 
                     40:        /**
                     41:         * Inbound message callback
                     42:         */
                     43:        vici_inbound_cb_t inbound;
                     44: 
                     45:        /**
                     46:         * Client connect callback
                     47:         */
                     48:        vici_connect_cb_t connect;
                     49: 
                     50:        /**
                     51:         * Client disconnect callback
                     52:         */
                     53:        vici_disconnect_cb_t disconnect;
                     54: 
                     55:        /**
                     56:         * Next client connection identifier
                     57:         */
                     58:        u_int nextid;
                     59: 
                     60:        /**
                     61:         * User data for callbacks
                     62:         */
                     63:        void *user;
                     64: 
                     65:        /**
                     66:         * Service accepting vici connections
                     67:         */
                     68:        stream_service_t *service;
                     69: 
                     70:        /**
                     71:         * Client connections, as entry_t
                     72:         */
                     73:        linked_list_t *connections;
                     74: 
                     75:        /**
                     76:         * mutex for client connections
                     77:         */
                     78:        mutex_t *mutex;
                     79: };
                     80: 
                     81: /**
                     82:  * Data to securely reference an entry
                     83:  */
                     84: typedef struct {
                     85:        /* reference to socket instance */
                     86:        private_vici_socket_t *this;
                     87:        /** connection identifier of entry */
                     88:        u_int id;
                     89: } entry_selector_t;
                     90: 
                     91: /**
                     92:  * Partially processed message
                     93:  */
                     94: typedef struct {
                     95:        /** bytes of length header sent/received */
                     96:        u_char hdrlen;
                     97:        /** bytes of length header */
                     98:        char hdr[sizeof(uint32_t)];
                     99:        /** send/receive buffer on heap */
                    100:        chunk_t buf;
                    101:        /** bytes sent/received in buffer */
                    102:        uint32_t done;
                    103: } msg_buf_t;
                    104: 
                    105: /**
                    106:  * Client connection entry
                    107:  */
                    108: typedef struct {
                    109:        /** reference to socket */
                    110:        private_vici_socket_t *this;
                    111:        /** associated stream */
                    112:        stream_t *stream;
                    113:        /** queued messages to send, as msg_buf_t pointers */
                    114:        array_t *out;
                    115:        /** input message buffer */
                    116:        msg_buf_t in;
                    117:        /** queued input messages to process, as chunk_t */
                    118:        array_t *queue;
                    119:        /** do we have job processing input queue? */
                    120:        bool has_processor;
                    121:        /** is this client disconnecting */
                    122:        bool disconnecting;
                    123:        /** client connection identifier */
                    124:        u_int id;
                    125:        /** any users reading over this connection? */
                    126:        int readers;
                    127:        /** any users writing over this connection? */
                    128:        int writers;
                    129:        /** condvar to wait for usage  */
                    130:        condvar_t *cond;
                    131: } entry_t;
                    132: 
                    133: /**
                    134:  * Destroy an connection entry
                    135:  */
                    136: CALLBACK(destroy_entry, void,
                    137:        entry_t *entry)
                    138: {
                    139:        msg_buf_t *out;
                    140:        chunk_t chunk;
                    141: 
                    142:        entry->stream->destroy(entry->stream);
                    143:        entry->this->disconnect(entry->this->user, entry->id);
                    144:        entry->cond->destroy(entry->cond);
                    145: 
                    146:        while (array_remove(entry->out, ARRAY_TAIL, &out))
                    147:        {
                    148:                chunk_clear(&out->buf);
                    149:                free(out);
                    150:        }
                    151:        array_destroy(entry->out);
                    152:        while (array_remove(entry->queue, ARRAY_TAIL, &chunk))
                    153:        {
                    154:                chunk_clear(&chunk);
                    155:        }
                    156:        array_destroy(entry->queue);
                    157:        chunk_clear(&entry->in.buf);
                    158:        free(entry);
                    159: }
                    160: 
                    161: /**
                    162:  * Find entry by stream (if given) or id, claim use
                    163:  */
                    164: static entry_t* find_entry(private_vici_socket_t *this, stream_t *stream,
                    165:                                                   u_int id, bool reader, bool writer)
                    166: {
                    167:        enumerator_t *enumerator;
                    168:        entry_t *entry, *found = NULL;
                    169:        bool candidate = TRUE;
                    170: 
                    171:        this->mutex->lock(this->mutex);
                    172:        while (candidate && !found)
                    173:        {
                    174:                candidate = FALSE;
                    175:                enumerator = this->connections->create_enumerator(this->connections);
                    176:                while (enumerator->enumerate(enumerator, &entry))
                    177:                {
                    178:                        if (stream)
                    179:                        {
                    180:                                if (entry->stream != stream)
                    181:                                {
                    182:                                        continue;
                    183:                                }
                    184:                        }
                    185:                        else
                    186:                        {
                    187:                                if (entry->id != id)
                    188:                                {
                    189:                                        continue;
                    190:                                }
                    191:                        }
                    192:                        if (entry->disconnecting)
                    193:                        {
                    194:                                continue;
                    195:                        }
                    196:                        candidate = TRUE;
                    197: 
                    198:                        if ((reader && entry->readers) ||
                    199:                                (writer && entry->writers))
                    200:                        {
                    201:                                entry->cond->wait(entry->cond, this->mutex);
                    202:                                break;
                    203:                        }
                    204:                        if (reader)
                    205:                        {
                    206:                                entry->readers++;
                    207:                        }
                    208:                        if (writer)
                    209:                        {
                    210:                                entry->writers++;
                    211:                        }
                    212:                        found = entry;
                    213:                        break;
                    214:                }
                    215:                enumerator->destroy(enumerator);
                    216:        }
                    217:        this->mutex->unlock(this->mutex);
                    218: 
                    219:        return found;
                    220: }
                    221: 
                    222: /**
                    223:  * Remove entry by id, claim use
                    224:  */
                    225: static entry_t* remove_entry(private_vici_socket_t *this, u_int id)
                    226: {
                    227:        enumerator_t *enumerator;
                    228:        entry_t *entry, *found = NULL;
                    229:        bool candidate = TRUE;
                    230: 
                    231:        this->mutex->lock(this->mutex);
                    232:        while (candidate && !found)
                    233:        {
                    234:                candidate = FALSE;
                    235:                enumerator = this->connections->create_enumerator(this->connections);
                    236:                while (enumerator->enumerate(enumerator, &entry))
                    237:                {
                    238:                        if (entry->id == id)
                    239:                        {
                    240:                                candidate = TRUE;
                    241:                                if (entry->readers || entry->writers)
                    242:                                {
                    243:                                        entry->cond->wait(entry->cond, this->mutex);
                    244:                                        break;
                    245:                                }
                    246:                                this->connections->remove_at(this->connections, enumerator);
                    247:                                found = entry;
                    248:                                break;
                    249:                        }
                    250:                }
                    251:                enumerator->destroy(enumerator);
                    252:        }
                    253:        this->mutex->unlock(this->mutex);
                    254: 
                    255:        return found;
                    256: }
                    257: 
                    258: /**
                    259:  * Release a claimed entry
                    260:  */
                    261: static void put_entry(private_vici_socket_t *this, entry_t *entry,
                    262:                                          bool reader, bool writer)
                    263: {
                    264:        this->mutex->lock(this->mutex);
                    265:        if (reader)
                    266:        {
                    267:                entry->readers--;
                    268:        }
                    269:        if (writer)
                    270:        {
                    271:                entry->writers--;
                    272:        }
                    273:        entry->cond->signal(entry->cond);
                    274:        this->mutex->unlock(this->mutex);
                    275: }
                    276: 
                    277: /**
                    278:  * Asynchronous callback to disconnect client
                    279:  */
                    280: CALLBACK(disconnect_async, job_requeue_t,
                    281:        entry_selector_t *sel)
                    282: {
                    283:        entry_t *entry;
                    284: 
                    285:        entry = remove_entry(sel->this, sel->id);
                    286:        if (entry)
                    287:        {
                    288:                destroy_entry(entry);
                    289:        }
                    290:        return JOB_REQUEUE_NONE;
                    291: }
                    292: 
                    293: /**
                    294:  * Disconnect a connected client
                    295:  */
                    296: static void disconnect(private_vici_socket_t *this, u_int id)
                    297: {
                    298:        entry_selector_t *sel;
                    299: 
                    300:        INIT(sel,
                    301:                .this = this,
                    302:                .id = id,
                    303:        );
                    304: 
                    305:        lib->processor->queue_job(lib->processor,
                    306:                        (job_t*)callback_job_create(disconnect_async, sel, free, NULL));
                    307: }
                    308: 
                    309: /**
                    310:  * Write queued output data
                    311:  */
                    312: static bool do_write(private_vici_socket_t *this, entry_t *entry,
                    313:                                         stream_t *stream, char *errmsg, size_t errlen)
                    314: {
                    315:        msg_buf_t *out;
                    316:        ssize_t len;
                    317: 
                    318:        while (array_get(entry->out, ARRAY_HEAD, &out))
                    319:        {
                    320:                /* write header */
                    321:                while (out->hdrlen < sizeof(out->hdr))
                    322:                {
                    323:                        len = stream->write(stream, out->hdr + out->hdrlen,
                    324:                                                                sizeof(out->hdr) - out->hdrlen, FALSE);
                    325:                        if (len == 0)
                    326:                        {
                    327:                                return FALSE;
                    328:                        }
                    329:                        if (len < 0)
                    330:                        {
                    331:                                if (errno == EWOULDBLOCK)
                    332:                                {
                    333:                                        return TRUE;
                    334:                                }
                    335:                                snprintf(errmsg, errlen, "vici header write error: %s",
                    336:                                                 strerror(errno));
                    337:                                return FALSE;
                    338:                        }
                    339:                        out->hdrlen += len;
                    340:                }
                    341: 
                    342:                /* write buffer buffer */
                    343:                while (out->buf.len > out->done)
                    344:                {
                    345:                        len = stream->write(stream, out->buf.ptr + out->done,
                    346:                                                                out->buf.len - out->done, FALSE);
                    347:                        if (len == 0)
                    348:                        {
                    349:                                snprintf(errmsg, errlen, "premature vici disconnect");
                    350:                                return FALSE;
                    351:                        }
                    352:                        if (len < 0)
                    353:                        {
                    354:                                if (errno == EWOULDBLOCK)
                    355:                                {
                    356:                                        return TRUE;
                    357:                                }
                    358:                                snprintf(errmsg, errlen, "vici write error: %s", strerror(errno));
                    359:                                return FALSE;
                    360:                        }
                    361:                        out->done += len;
                    362:                }
                    363: 
                    364:                if (array_remove(entry->out, ARRAY_HEAD, &out))
                    365:                {
                    366:                        chunk_clear(&out->buf);
                    367:                        free(out);
                    368:                }
                    369:        }
                    370:        return TRUE;
                    371: }
                    372: 
                    373: /**
                    374:  * Send pending messages
                    375:  */
                    376: CALLBACK(on_write, bool,
                    377:        private_vici_socket_t *this, stream_t *stream)
                    378: {
                    379:        char errmsg[256] = "";
                    380:        entry_t *entry;
                    381:        bool ret = FALSE;
                    382: 
                    383:        entry = find_entry(this, stream, 0, FALSE, TRUE);
                    384:        if (entry)
                    385:        {
                    386:                ret = do_write(this, entry, stream, errmsg, sizeof(errmsg));
                    387:                if (ret)
                    388:                {
                    389:                        /* unregister if we have no more messages to send */
                    390:                        ret = array_count(entry->out) != 0;
                    391:                }
                    392:                else
                    393:                {
                    394:                        entry->disconnecting = TRUE;
                    395:                        disconnect(entry->this, entry->id);
                    396:                }
                    397:                put_entry(this, entry, FALSE, TRUE);
                    398: 
                    399:                if (!ret && errmsg[0])
                    400:                {
                    401:                        DBG1(DBG_CFG, errmsg);
                    402:                }
                    403:        }
                    404: 
                    405:        return ret;
                    406: }
                    407: 
                    408: /**
                    409:  * Read in available header with data, non-blocking accumulating to buffer
                    410:  */
                    411: static bool do_read(private_vici_socket_t *this, entry_t *entry,
                    412:                                        stream_t *stream, char *errmsg, size_t errlen)
                    413: {
                    414:        uint32_t msglen;
                    415:        ssize_t len;
                    416: 
                    417:        /* assemble the length header first */
                    418:        while (entry->in.hdrlen < sizeof(entry->in.hdr))
                    419:        {
                    420:                len = stream->read(stream, entry->in.hdr + entry->in.hdrlen,
                    421:                                                   sizeof(entry->in.hdr) - entry->in.hdrlen, FALSE);
                    422:                if (len == 0)
                    423:                {
                    424:                        return FALSE;
                    425:                }
                    426:                if (len < 0)
                    427:                {
                    428:                        if (errno == EWOULDBLOCK)
                    429:                        {
                    430:                                return TRUE;
                    431:                        }
                    432:                        snprintf(errmsg, errlen, "vici header read error: %s",
                    433:                                         strerror(errno));
                    434:                        return FALSE;
                    435:                }
                    436:                entry->in.hdrlen += len;
                    437:                if (entry->in.hdrlen == sizeof(entry->in.hdr))
                    438:                {
                    439:                        msglen = untoh32(entry->in.hdr);
                    440:                        if (msglen > VICI_MESSAGE_SIZE_MAX)
                    441:                        {
                    442:                                snprintf(errmsg, errlen, "vici message length %u exceeds %u "
                    443:                                                 "bytes limit, ignored", msglen, VICI_MESSAGE_SIZE_MAX);
                    444:                                return FALSE;
                    445:                        }
                    446:                        /* header complete, continue with data */
                    447:                        entry->in.buf = chunk_alloc(msglen);
                    448:                }
                    449:        }
                    450: 
                    451:        /* assemble buffer */
                    452:        while (entry->in.buf.len > entry->in.done)
                    453:        {
                    454:                len = stream->read(stream, entry->in.buf.ptr + entry->in.done,
                    455:                                                   entry->in.buf.len - entry->in.done, FALSE);
                    456:                if (len == 0)
                    457:                {
                    458:                        snprintf(errmsg, errlen, "premature vici disconnect");
                    459:                        return FALSE;
                    460:                }
                    461:                if (len < 0)
                    462:                {
                    463:                        if (errno == EWOULDBLOCK)
                    464:                        {
                    465:                                return TRUE;
                    466:                        }
                    467:                        snprintf(errmsg, errlen, "vici read error: %s", strerror(errno));
                    468:                        return FALSE;
                    469:                }
                    470:                entry->in.done += len;
                    471:        }
                    472: 
                    473:        return TRUE;
                    474: }
                    475: 
                    476: /**
                    477:  * Callback processing incoming requests in strict order
                    478:  */
                    479: CALLBACK(process_queue, job_requeue_t,
                    480:        entry_selector_t *sel)
                    481: {
                    482:        entry_t *entry;
                    483:        chunk_t chunk;
                    484:        bool found;
                    485:        u_int id;
                    486: 
                    487:        while (TRUE)
                    488:        {
                    489:                entry = find_entry(sel->this, NULL, sel->id, TRUE, FALSE);
                    490:                if (!entry)
                    491:                {
                    492:                        break;
                    493:                }
                    494: 
                    495:                found = array_remove(entry->queue, ARRAY_HEAD, &chunk);
                    496:                if (!found)
                    497:                {
                    498:                        entry->has_processor = FALSE;
                    499:                }
                    500:                id = entry->id;
                    501:                put_entry(sel->this, entry, TRUE, FALSE);
                    502:                if (!found)
                    503:                {
                    504:                        break;
                    505:                }
                    506: 
                    507:                thread_cleanup_push(free, chunk.ptr);
                    508:                sel->this->inbound(sel->this->user, id, chunk);
                    509:                thread_cleanup_pop(TRUE);
                    510:        }
                    511:        return JOB_REQUEUE_NONE;
                    512: }
                    513: 
                    514: /**
                    515:  * Process incoming messages
                    516:  */
                    517: CALLBACK(on_read, bool,
                    518:        private_vici_socket_t *this, stream_t *stream)
                    519: {
                    520:        char errmsg[256] = "";
                    521:        entry_selector_t *sel;
                    522:        entry_t *entry;
                    523:        bool ret = FALSE;
                    524: 
                    525:        entry = find_entry(this, stream, 0, TRUE, FALSE);
                    526:        if (entry)
                    527:        {
                    528:                ret = do_read(this, entry, stream, errmsg, sizeof(errmsg));
                    529:                if (!ret)
                    530:                {
                    531:                        entry->disconnecting = TRUE;
                    532:                        disconnect(this, entry->id);
                    533:                }
                    534:                else if (entry->in.hdrlen == sizeof(entry->in.hdr) &&
                    535:                                 entry->in.buf.len == entry->in.done)
                    536:                {
                    537:                        array_insert(entry->queue, ARRAY_TAIL, &entry->in.buf);
                    538:                        entry->in.buf = chunk_empty;
                    539:                        entry->in.hdrlen = entry->in.done = 0;
                    540: 
                    541:                        if (!entry->has_processor)
                    542:                        {
                    543:                                INIT(sel,
                    544:                                        .this = this,
                    545:                                        .id = entry->id,
                    546:                                );
                    547:                                lib->processor->queue_job(lib->processor,
                    548:                                                        (job_t*)callback_job_create(process_queue,
                    549:                                                                                                                sel, free, NULL));
                    550:                                entry->has_processor = TRUE;
                    551:                        }
                    552:                }
                    553:                put_entry(this, entry, TRUE, FALSE);
                    554: 
                    555:                if (!ret && errmsg[0])
                    556:                {
                    557:                        DBG1(DBG_CFG, errmsg);
                    558:                }
                    559:        }
                    560: 
                    561:        return ret;
                    562: }
                    563: 
                    564: /**
                    565:  * Process connection request
                    566:  */
                    567: CALLBACK(on_accept, bool,
                    568:        private_vici_socket_t *this, stream_t *stream)
                    569: {
                    570:        entry_t *entry;
                    571:        u_int id;
                    572: 
                    573:        id = ref_get(&this->nextid);
                    574: 
                    575:        INIT(entry,
                    576:                .this = this,
                    577:                .stream = stream,
                    578:                .id = id,
                    579:                .out = array_create(0, 0),
                    580:                .queue = array_create(sizeof(chunk_t), 0),
                    581:                .cond = condvar_create(CONDVAR_TYPE_DEFAULT),
                    582:                .readers = 1,
                    583:        );
                    584: 
                    585:        this->mutex->lock(this->mutex);
                    586:        this->connections->insert_last(this->connections, entry);
                    587:        this->mutex->unlock(this->mutex);
                    588: 
                    589:        stream->on_read(stream, on_read, this);
                    590: 
                    591:        put_entry(this, entry, TRUE, FALSE);
                    592: 
                    593:        this->connect(this->user, id);
                    594: 
                    595:        return TRUE;
                    596: }
                    597: 
                    598: /**
                    599:  * Async callback to enable writer
                    600:  */
                    601: CALLBACK(enable_writer, job_requeue_t,
                    602:        entry_selector_t *sel)
                    603: {
                    604:        entry_t *entry;
                    605: 
                    606:        entry = find_entry(sel->this, NULL, sel->id, FALSE, TRUE);
                    607:        if (entry)
                    608:        {
                    609:                entry->stream->on_write(entry->stream, on_write, sel->this);
                    610:                put_entry(sel->this, entry, FALSE, TRUE);
                    611:        }
                    612:        return JOB_REQUEUE_NONE;
                    613: }
                    614: 
                    615: METHOD(vici_socket_t, send_, void,
                    616:        private_vici_socket_t *this, u_int id, chunk_t msg)
                    617: {
                    618:        if (msg.len <= VICI_MESSAGE_SIZE_MAX)
                    619:        {
                    620:                entry_selector_t *sel;
                    621:                msg_buf_t *out;
                    622:                entry_t *entry;
                    623: 
                    624:                entry = find_entry(this, NULL, id, FALSE, TRUE);
                    625:                if (entry)
                    626:                {
                    627:                        INIT(out,
                    628:                                .buf = msg,
                    629:                        );
                    630:                        htoun32(out->hdr, msg.len);
                    631: 
                    632:                        array_insert(entry->out, ARRAY_TAIL, out);
                    633:                        if (array_count(entry->out) == 1)
                    634:                        {       /* asynchronously re-enable on_write callback when we get data */
                    635:                                INIT(sel,
                    636:                                        .this = this,
                    637:                                        .id = entry->id,
                    638:                                );
                    639:                                lib->processor->queue_job(lib->processor,
                    640:                                                        (job_t*)callback_job_create(enable_writer,
                    641:                                                                                                                sel, free, NULL));
                    642:                        }
                    643:                        put_entry(this, entry, FALSE, TRUE);
                    644:                }
                    645:                else
                    646:                {
                    647:                        DBG1(DBG_CFG, "vici connection %u unknown", id);
                    648:                        chunk_clear(&msg);
                    649:                }
                    650:        }
                    651:        else
                    652:        {
                    653:                DBG1(DBG_CFG, "vici message size %zu exceeds maximum size of %u, "
                    654:                         "discarded", msg.len, VICI_MESSAGE_SIZE_MAX);
                    655:                chunk_clear(&msg);
                    656:        }
                    657: }
                    658: 
                    659: METHOD(vici_socket_t, destroy, void,
                    660:        private_vici_socket_t *this)
                    661: {
                    662:        DESTROY_IF(this->service);
                    663:        this->connections->destroy_function(this->connections, destroy_entry);
                    664:        this->mutex->destroy(this->mutex);
                    665:        free(this);
                    666: }
                    667: 
                    668: /*
                    669:  * see header file
                    670:  */
                    671: vici_socket_t *vici_socket_create(char *uri, vici_inbound_cb_t inbound,
                    672:                                                                  vici_connect_cb_t connect,
                    673:                                                                  vici_disconnect_cb_t disconnect, void *user)
                    674: {
                    675:        private_vici_socket_t *this;
                    676: 
                    677:        INIT(this,
                    678:                .public = {
                    679:                        .send = _send_,
                    680:                        .destroy = _destroy,
                    681:                },
                    682:                .mutex = mutex_create(MUTEX_TYPE_DEFAULT),
                    683:                .connections = linked_list_create(),
                    684:                .inbound = inbound,
                    685:                .connect = connect,
                    686:                .disconnect = disconnect,
                    687:                .user = user,
                    688:        );
                    689: 
                    690:        this->service = lib->streams->create_service(lib->streams, uri, 3);
                    691:        if (!this->service)
                    692:        {
                    693:                DBG1(DBG_CFG, "creating vici socket failed");
                    694:                destroy(this);
                    695:                return NULL;
                    696:        }
                    697:        this->service->on_accept(this->service, on_accept, this,
                    698:                                                         JOB_PRIO_CRITICAL, 0);
                    699: 
                    700:        return &this->public;
                    701: }

FreeBSD-CVSweb <freebsd-cvsweb@FreeBSD.org>