Annotation of embedaddon/strongswan/scripts/tls_test.c, revision 1.1.1.1

1.1       misho       1: /*
                      2:  * Copyright (C) 2010 Martin Willi
                      3:  * Copyright (C) 2010 revosec AG
                      4:  *
                      5:  * This program is free software; you can redistribute it and/or modify it
                      6:  * under the terms of the GNU General Public License as published by the
                      7:  * Free Software Foundation; either version 2 of the License, or (at your
                      8:  * option) any later version.  See <http://www.fsf.org/copyleft/gpl.txt>.
                      9:  *
                     10:  * This program is distributed in the hope that it will be useful, but
                     11:  * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
                     12:  * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
                     13:  * for more details.
                     14:  */
                     15: 
                     16: #include <unistd.h>
                     17: #include <stdio.h>
                     18: #include <sys/types.h>
                     19: #include <sys/socket.h>
                     20: #include <getopt.h>
                     21: #include <errno.h>
                     22: #include <string.h>
                     23: 
                     24: #include <library.h>
                     25: #include <utils/debug.h>
                     26: #include <tls_socket.h>
                     27: #include <networking/host.h>
                     28: #include <credentials/sets/mem_cred.h>
                     29: 
                     30: /**
                     31:  * Print usage information
                     32:  */
                     33: static void usage(FILE *out, char *cmd)
                     34: {
                     35:        fprintf(out, "usage:\n");
                     36:        fprintf(out, "  %s --connect <address> --port <port> [--key <key] [--cert <file>]+ [--times <n>]\n", cmd);
                     37:        fprintf(out, "  %s --listen <address> --port <port> --key <key> [--cert <file>]+ [--times <n>]\n", cmd);
                     38: }
                     39: 
                     40: /**
                     41:  * Check, as client, if we have a client certificate with private key
                     42:  */
                     43: static identification_t *find_client_id()
                     44: {
                     45:        identification_t *client = NULL, *keyid;
                     46:        enumerator_t *enumerator;
                     47:        certificate_t *cert;
                     48:        public_key_t *pubkey;
                     49:        private_key_t *privkey;
                     50:        chunk_t chunk;
                     51: 
                     52:        enumerator = lib->credmgr->create_cert_enumerator(lib->credmgr,
                     53:                                                                                        CERT_X509, KEY_ANY, NULL, FALSE);
                     54:        while (enumerator->enumerate(enumerator, &cert))
                     55:        {
                     56:                pubkey = cert->get_public_key(cert);
                     57:                if (pubkey)
                     58:                {
                     59:                        if (pubkey->get_fingerprint(pubkey, KEYID_PUBKEY_SHA1, &chunk))
                     60:                        {
                     61:                                keyid = identification_create_from_encoding(ID_KEY_ID, chunk);
                     62:                                privkey = lib->credmgr->get_private(lib->credmgr,
                     63:                                                                        pubkey->get_type(pubkey), keyid, NULL);
                     64:                                keyid->destroy(keyid);
                     65:                                if (privkey)
                     66:                                {
                     67:                                        client = cert->get_subject(cert);
                     68:                                        client = client->clone(client);
                     69:                                        privkey->destroy(privkey);
                     70:                                }
                     71:                        }
                     72:                        pubkey->destroy(pubkey);
                     73:                }
                     74:                if (client)
                     75:                {
                     76:                        break;
                     77:                }
                     78:        }
                     79:        enumerator->destroy(enumerator);
                     80: 
                     81:        return client;
                     82: }
                     83: 
                     84: /**
                     85:  * Client routine
                     86:  */
                     87: static int run_client(host_t *host, identification_t *server,
                     88:                                          identification_t *client, int times, tls_cache_t *cache)
                     89: {
                     90:        tls_socket_t *tls;
                     91:        int fd, res;
                     92: 
                     93:        while (times == -1 || times-- > 0)
                     94:        {
                     95:                fd = socket(AF_INET, SOCK_STREAM, 0);
                     96:                if (fd == -1)
                     97:                {
                     98:                        DBG1(DBG_TLS, "opening socket failed: %s", strerror(errno));
                     99:                        return 1;
                    100:                }
                    101:                if (connect(fd, host->get_sockaddr(host),
                    102:                                        *host->get_sockaddr_len(host)) == -1)
                    103:                {
                    104:                        DBG1(DBG_TLS, "connecting to %#H failed: %s", host, strerror(errno));
                    105:                        close(fd);
                    106:                        return 1;
                    107:                }
                    108:                tls = tls_socket_create(FALSE, server, client, fd, cache, TLS_1_2, TRUE);
                    109:                if (!tls)
                    110:                {
                    111:                        close(fd);
                    112:                        return 1;
                    113:                }
                    114:                res = tls->splice(tls, 0, 1) ? 0 : 1;
                    115:                tls->destroy(tls);
                    116:                close(fd);
                    117:                if (res)
                    118:                {
                    119:                        break;
                    120:                }
                    121:        }
                    122:        return res;
                    123: }
                    124: 
                    125: /**
                    126:  * Server routine
                    127:  */
                    128: static int serve(host_t *host, identification_t *server,
                    129:                                 int times, tls_cache_t *cache)
                    130: {
                    131:        tls_socket_t *tls;
                    132:        int fd, cfd;
                    133: 
                    134:        fd = socket(AF_INET, SOCK_STREAM, 0);
                    135:        if (fd == -1)
                    136:        {
                    137:                DBG1(DBG_TLS, "opening socket failed: %s", strerror(errno));
                    138:                return 1;
                    139:        }
                    140:        if (bind(fd, host->get_sockaddr(host),
                    141:                         *host->get_sockaddr_len(host)) == -1)
                    142:        {
                    143:                DBG1(DBG_TLS, "binding to %#H failed: %s", host, strerror(errno));
                    144:                close(fd);
                    145:                return 1;
                    146:        }
                    147:        if (listen(fd, 1) == -1)
                    148:        {
                    149:                DBG1(DBG_TLS, "listen to %#H failed: %m", host, strerror(errno));
                    150:                close(fd);
                    151:                return 1;
                    152:        }
                    153: 
                    154:        while (times == -1 || times-- > 0)
                    155:        {
                    156:                cfd = accept(fd, host->get_sockaddr(host), host->get_sockaddr_len(host));
                    157:                if (cfd == -1)
                    158:                {
                    159:                        DBG1(DBG_TLS, "accept failed: %s", strerror(errno));
                    160:                        close(fd);
                    161:                        return 1;
                    162:                }
                    163:                DBG1(DBG_TLS, "%#H connected", host);
                    164: 
                    165:                tls = tls_socket_create(TRUE, server, NULL, cfd, cache, TLS_1_2, TRUE);
                    166:                if (!tls)
                    167:                {
                    168:                        close(fd);
                    169:                        return 1;
                    170:                }
                    171:                tls->splice(tls, 0, 1);
                    172:                DBG1(DBG_TLS, "%#H disconnected", host);
                    173:                tls->destroy(tls);
                    174:        }
                    175:        close(fd);
                    176: 
                    177:        return 0;
                    178: }
                    179: 
                    180: /**
                    181:  * In-Memory credential set
                    182:  */
                    183: static mem_cred_t *creds;
                    184: 
                    185: /**
                    186:  * Load certificate from file
                    187:  */
                    188: static bool load_certificate(char *filename)
                    189: {
                    190:        certificate_t *cert;
                    191: 
                    192:        cert = lib->creds->create(lib->creds, CRED_CERTIFICATE, CERT_X509,
                    193:                                                          BUILD_FROM_FILE, filename, BUILD_END);
                    194:        if (!cert)
                    195:        {
                    196:                DBG1(DBG_TLS, "loading certificate from '%s' failed", filename);
                    197:                return FALSE;
                    198:        }
                    199:        creds->add_cert(creds, TRUE, cert);
                    200:        return TRUE;
                    201: }
                    202: 
                    203: /**
                    204:  * Load private key from file
                    205:  */
                    206: static bool load_key(char *filename)
                    207: {
                    208:        private_key_t *key;
                    209: 
                    210:        key = lib->creds->create(lib->creds, CRED_PRIVATE_KEY, KEY_RSA,
                    211:                                                          BUILD_FROM_FILE, filename, BUILD_END);
                    212:        if (!key)
                    213:        {
                    214:                DBG1(DBG_TLS, "loading key from '%s' failed", filename);
                    215:                return FALSE;
                    216:        }
                    217:        creds->add_key(creds, key);
                    218:        return TRUE;
                    219: }
                    220: 
                    221: /**
                    222:  * TLS debug level
                    223:  */
                    224: static level_t tls_level = 1;
                    225: 
                    226: static void dbg_tls(debug_t group, level_t level, char *fmt, ...)
                    227: {
                    228:        if ((group == DBG_TLS && level <= tls_level) || level <= 1)
                    229:        {
                    230:                va_list args;
                    231: 
                    232:                va_start(args, fmt);
                    233:                vfprintf(stderr, fmt, args);
                    234:                fprintf(stderr, "\n");
                    235:                va_end(args);
                    236:        }
                    237: }
                    238: 
                    239: /**
                    240:  * Cleanup
                    241:  */
                    242: static void cleanup()
                    243: {
                    244:        lib->credmgr->remove_set(lib->credmgr, &creds->set);
                    245:        creds->destroy(creds);
                    246:        library_deinit();
                    247: }
                    248: 
                    249: /**
                    250:  * Initialize library
                    251:  */
                    252: static void init()
                    253: {
                    254:        library_init(NULL, "tls_test");
                    255: 
                    256:        dbg = dbg_tls;
                    257: 
                    258:        lib->plugins->load(lib->plugins, PLUGINS);
                    259: 
                    260:        creds = mem_cred_create();
                    261:        lib->credmgr->add_set(lib->credmgr, &creds->set);
                    262: 
                    263:        atexit(cleanup);
                    264: }
                    265: 
                    266: int main(int argc, char *argv[])
                    267: {
                    268:        char *address = NULL;
                    269:        bool listen = FALSE;
                    270:        int port = 0, times = -1, res;
                    271:        identification_t *server, *client;
                    272:        tls_cache_t *cache;
                    273:        host_t *host;
                    274: 
                    275:        init();
                    276: 
                    277:        while (TRUE)
                    278:        {
                    279:                struct option long_opts[] = {
                    280:                        {"help",                no_argument,                    NULL,           'h' },
                    281:                        {"connect",             required_argument,              NULL,           'c' },
                    282:                        {"listen",              required_argument,              NULL,           'l' },
                    283:                        {"port",                required_argument,              NULL,           'p' },
                    284:                        {"cert",                required_argument,              NULL,           'x' },
                    285:                        {"key",                 required_argument,              NULL,           'k' },
                    286:                        {"times",               required_argument,              NULL,           't' },
                    287:                        {"debug",               required_argument,              NULL,           'd' },
                    288:                        {0,0,0,0 }
                    289:                };
                    290:                switch (getopt_long(argc, argv, "", long_opts, NULL))
                    291:                {
                    292:                        case EOF:
                    293:                                break;
                    294:                        case 'h':
                    295:                                usage(stdout, argv[0]);
                    296:                                return 0;
                    297:                        case 'x':
                    298:                                if (!load_certificate(optarg))
                    299:                                {
                    300:                                        return 1;
                    301:                                }
                    302:                                continue;
                    303:                        case 'k':
                    304:                                if (!load_key(optarg))
                    305:                                {
                    306:                                        return 1;
                    307:                                }
                    308:                                continue;
                    309:                        case 'l':
                    310:                                listen = TRUE;
                    311:                                /* fall */
                    312:                        case 'c':
                    313:                                if (address)
                    314:                                {
                    315:                                        usage(stderr, argv[0]);
                    316:                                        return 1;
                    317:                                }
                    318:                                address = optarg;
                    319:                                continue;
                    320:                        case 'p':
                    321:                                port = atoi(optarg);
                    322:                                continue;
                    323:                        case 't':
                    324:                                times = atoi(optarg);
                    325:                                continue;
                    326:                        case 'd':
                    327:                                tls_level = atoi(optarg);
                    328:                                continue;
                    329:                        default:
                    330:                                usage(stderr, argv[0]);
                    331:                                return 1;
                    332:                }
                    333:                break;
                    334:        }
                    335:        if (!port || !address)
                    336:        {
                    337:                usage(stderr, argv[0]);
                    338:                return 1;
                    339:        }
                    340:        host = host_create_from_dns(address, 0, port);
                    341:        if (!host)
                    342:        {
                    343:                DBG1(DBG_TLS, "resolving hostname %s failed", address);
                    344:                return 1;
                    345:        }
                    346:        server = identification_create_from_string(address);
                    347:        cache = tls_cache_create(100, 30);
                    348:        if (listen)
                    349:        {
                    350:                res = serve(host, server, times, cache);
                    351:        }
                    352:        else
                    353:        {
                    354:                client = find_client_id();
                    355:                res = run_client(host, server, client, times, cache);
                    356:                DESTROY_IF(client);
                    357:        }
                    358:        cache->destroy(cache);
                    359:        host->destroy(host);
                    360:        server->destroy(server);
                    361:        return res;
                    362: }

FreeBSD-CVSweb <freebsd-cvsweb@FreeBSD.org>