Annotation of embedaddon/strongswan/src/libstrongswan/plugins/ntru/ntru_private_key.c, revision 1.1.1.1

1.1       misho       1: /*
                      2:  * Copyright (C) 2014-2016 Andreas Steffen
                      3:  * HSR Hochschule fuer Technik Rapperswil
                      4:  *
                      5:  * Copyright (C) 2009-2013  Security Innovation
                      6:  *
                      7:  * This program is free software; you can redistribute it and/or modify it
                      8:  * under the terms of the GNU General Public License as published by the
                      9:  * Free Software Foundation; either version 2 of the License, or (at your
                     10:  * option) any later version.  See <http://www.fsf.org/copyleft/gpl.txt>.
                     11:  *
                     12:  * This program is distributed in the hope that it will be useful, but
                     13:  * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
                     14:  * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
                     15:  * for more details.
                     16:  */
                     17: 
                     18: #include "ntru_private_key.h"
                     19: #include "ntru_trits.h"
                     20: #include "ntru_poly.h"
                     21: #include "ntru_convert.h"
                     22: 
                     23: #include <utils/debug.h>
                     24: #include <utils/test.h>
                     25: 
                     26: typedef struct private_ntru_private_key_t private_ntru_private_key_t;
                     27: 
                     28: /**
                     29:  * Private data of an ntru_private_key_t object.
                     30:  */
                     31: struct private_ntru_private_key_t {
                     32: 
                     33:        /**
                     34:         * Public ntru_private_key_t interface.
                     35:         */
                     36:        ntru_private_key_t public;
                     37: 
                     38:        /**
                     39:         * NTRU Parameter Set
                     40:         */
                     41:        const ntru_param_set_t *params;
                     42: 
                     43:        /**
                     44:         * Polynomial F which is the private key
                     45:         */
                     46:        ntru_poly_t *privkey;
                     47: 
                     48:        /**
                     49:         * Polynomial h which is the public key
                     50:         */
                     51:        uint16_t *pubkey;
                     52: 
                     53:        /**
                     54:         * Encoding of the private key
                     55:         */
                     56:        chunk_t encoding;
                     57: 
                     58:        /**
                     59:         * Deterministic Random Bit Generator
                     60:         */
                     61:        drbg_t *drbg;
                     62: 
                     63: };
                     64: 
                     65: METHOD(ntru_private_key_t, get_id, ntru_param_set_id_t,
                     66:        private_ntru_private_key_t *this)
                     67: {
                     68:        return this->params->id;
                     69: }
                     70: 
                     71: METHOD(ntru_private_key_t, get_public_key, ntru_public_key_t*,
                     72:        private_ntru_private_key_t *this)
                     73: {
                     74:        return ntru_public_key_create(this->drbg, this->params, this->pubkey);
                     75: }
                     76: 
                     77: /**
                     78:  * Generate NTRU encryption private key encoding
                     79:  */
                     80: static void generate_encoding(private_ntru_private_key_t *this)
                     81: {
                     82:        size_t pubkey_len, privkey_len, privkey_trits_len, privkey_indices_len;
                     83:        int privkey_pack_type;
                     84:        uint16_t *indices;
                     85:        uint8_t *trits;
                     86:        u_char *enc;
                     87: 
                     88:        /* compute public key length encoded as packed coefficients */
                     89:        pubkey_len =  (this->params->N * this->params->q_bits + 7) / 8;
                     90: 
                     91:        /* compute private key length encoded as packed trits coefficients */
                     92:        privkey_trits_len = (this->params->N + 4) / 5;
                     93: 
                     94:        /* compute private key length encoded as packed indices */
                     95:        privkey_indices_len = (this->privkey->get_size(this->privkey) *
                     96:                                                   this->params->N_bits + 7) / 8;
                     97: 
                     98:        if (this->params->is_product_form ||
                     99:                privkey_indices_len <= privkey_trits_len)
                    100:        {
                    101:                privkey_pack_type = NTRU_KEY_PACKED_INDICES;
                    102:                privkey_len = privkey_indices_len;
                    103:        }
                    104:        else
                    105:        {
                    106:                privkey_pack_type = NTRU_KEY_PACKED_TRITS;
                    107:                privkey_len = privkey_trits_len;
                    108:        }
                    109: 
                    110:        /* allocate memory for private key encoding */
                    111:        this->encoding = chunk_alloc(2 + NTRU_OID_LEN + pubkey_len + privkey_len);
                    112:        enc = this->encoding.ptr;
                    113: 
                    114:        /* format header and packed public key */
                    115:        *enc++ = NTRU_PRIVKEY_DEFAULT_TAG;
                    116:        *enc++ = NTRU_OID_LEN;
                    117:        memcpy(enc, this->params->oid, NTRU_OID_LEN);
                    118:        enc += NTRU_OID_LEN;
                    119:        ntru_elements_2_octets(this->params->N, this->pubkey,
                    120:                                                   this->params->q_bits, enc);
                    121:        enc += pubkey_len;
                    122: 
                    123:        /* add packed private key */
                    124:        indices = this->privkey->get_indices(this->privkey);
                    125: 
                    126:        if (privkey_pack_type == NTRU_KEY_PACKED_TRITS)
                    127:        {
                    128:                /* encode private key as packed trits */
                    129:                trits = malloc(this->params->N);
                    130:                ntru_indices_2_packed_trits(indices, this->params->dF_r,
                    131:                                                        this->params->dF_r, this->params->N, trits, enc);
                    132:                memwipe(trits, this->params->N);
                    133:                free(trits);
                    134:        }
                    135:        else
                    136:        {
                    137:                /* encode private key as packed indices */
                    138:                ntru_elements_2_octets(this->privkey->get_size(this->privkey),
                    139:                                                           indices, this->params->N_bits, enc);
                    140:        }
                    141: }
                    142: 
                    143: METHOD(ntru_private_key_t, get_encoding, chunk_t,
                    144:        private_ntru_private_key_t *this)
                    145: {
                    146:        return this->encoding;
                    147: }
                    148: 
                    149: /**
                    150:  * Checks that the number of 0, +1, and -1 trinary ring elements meet or exceed
                    151:  * a minimum weight.
                    152:  *
                    153:  * @param N                    degree of polynomial
                    154:  * @param t                    array of trinary ring elements
                    155:  * @param min_wt       minimum weight
                    156:  * @return                     TRUE if minimum weight met or exceeded
                    157:  */
                    158: bool ntru_check_min_weight(uint16_t N, uint8_t  *t, uint16_t min_wt)
                    159: {
                    160:        uint16_t wt[3];
                    161:        bool success;
                    162:        int i;
                    163: 
                    164:        wt[0] = wt[1] = wt[2] = 0;
                    165: 
                    166:        for (i = 0; i < N; i++)
                    167:        {
                    168:                ++wt[t[i]];
                    169:        }
                    170:        success = (wt[0] >= min_wt) && (wt[1] >= min_wt) && (wt[2] >= min_wt);
                    171: 
                    172:        DBG2(DBG_LIB, "minimum weight = %u, so -1: %u, 0: %u, +1: %u is %sok",
                    173:                                   min_wt, wt[2], wt[0], wt[1], success ? "" : "not ");
                    174: 
                    175:        return success;
                    176: }
                    177: 
                    178: METHOD(ntru_private_key_t, decrypt, bool,
                    179:        private_ntru_private_key_t *this, chunk_t ciphertext, chunk_t *plaintext)
                    180: {
                    181:        ext_out_function_t alg;
                    182:        size_t t_len, seed1_len, seed2_len;
                    183:        uint16_t *t1, *t2, *t = NULL;
                    184:     uint16_t mod_q_mask, q_mod_p, cmprime_len, cm_len = 0, num_zeros;
                    185:        uint8_t *Mtrin, *M, *cm, *mask_trits, *ptr;
                    186:        int16_t m1 = 0;
                    187:        chunk_t seed = chunk_empty;
                    188:        ntru_trits_t *mask;
                    189:        ntru_poly_t *r_poly;
                    190:        bool msg_rep_good, success = TRUE;
                    191:        int i;
                    192: 
                    193:        *plaintext = chunk_empty;
                    194: 
                    195:        if (ciphertext.len != (this->params->N * this->params->q_bits + 7) / 8)
                    196:        {
                    197:                DBG1(DBG_LIB, "wrong NTRU ciphertext length");
                    198:                return FALSE;
                    199:        }
                    200: 
                    201:        /* allocate temporary array t */
                    202:        t_len  = 2 * this->params->N * sizeof(uint16_t);
                    203:        t = malloc(t_len);
                    204:        t1 = t;
                    205:        t2 = t + this->params->N;
                    206:        Mtrin = (uint8_t *)t1;
                    207:        M = Mtrin + this->params->N;
                    208: 
                    209:        /* set MGF1 algorithm type based on security strength */
                    210:        alg = (this->params->sec_strength_len <= 20) ? XOF_MGF1_SHA1 :
                    211:                                                                                                   XOF_MGF1_SHA256;
                    212: 
                    213:        /* set constants */
                    214:        mod_q_mask = this->params->q - 1;
                    215:        q_mod_p = this->params->q % 3;
                    216: 
                    217:     /* unpack the ciphertext */
                    218:     ntru_octets_2_elements(ciphertext.len, ciphertext.ptr,
                    219:                                                   this->params->q_bits, t2);
                    220: 
                    221:        /* form cm':
                    222:         *  F * e
                    223:         *  A = e * (1 + pF) mod q = e + pFe mod q
                    224:         *  a = A in the range [-q/2, q/2)
                    225:         *  cm' = a mod p
                    226:         */
                    227:        this->privkey->ring_mult(this->privkey, t2, t1);
                    228: 
                    229:        cmprime_len = this->params->N;
                    230:        if (this->params->is_product_form)
                    231:        {
                    232:                --cmprime_len;
                    233:                for (i = 0; i < cmprime_len; i++)
                    234:                {
                    235:                        t1[i] = (t2[i] + 3 * t1[i]) & mod_q_mask;
                    236:                        if (t1[i] >= (this->params->q / 2))
                    237:                        {
                    238:                                t1[i] -= q_mod_p;
                    239:                        }
                    240:                        Mtrin[i] = (uint8_t)(t1[i] % 3);
                    241:                        if (Mtrin[i] == 1)
                    242:                        {
                    243:                                ++m1;
                    244:                        }
                    245:                        else if (Mtrin[i] == 2)
                    246:                        {
                    247:                                --m1;
                    248:                        }
                    249:                }
                    250:        }
                    251:        else
                    252:        {
                    253:                for (i = 0; i < cmprime_len; i++)
                    254:                {
                    255:                        t1[i] = (t2[i] + 3 * t1[i]) & mod_q_mask;
                    256:                        if (t1[i] >= (this->params->q / 2))
                    257:                        {
                    258:                                t1[i] -= q_mod_p;
                    259:                        }
                    260:                        Mtrin[i] = (uint8_t)(t1[i] % 3);
                    261:                }
                    262:        }
                    263: 
                    264:     /**
                    265:         * check that the candidate message representative meets
                    266:      * minimum weight requirements
                    267:      */
                    268:        if (this->params->is_product_form)
                    269:        {
                    270:                msg_rep_good = (abs(m1) <= this->params->min_msg_rep_wt);
                    271:        }
                    272:        else
                    273:        {
                    274:                msg_rep_good = ntru_check_min_weight(cmprime_len, Mtrin,
                    275:                                                                                         this->params->min_msg_rep_wt);
                    276:        }
                    277:        if (!msg_rep_good)
                    278:        {
                    279:                DBG1(DBG_LIB, "decryption failed due to insufficient minimum weight");
                    280:                success = FALSE;
                    281:        }
                    282: 
                    283:        /* form cR = e - cm' mod q */
                    284:        for (i = 0; i < cmprime_len; i++)
                    285:        {
                    286:                if (Mtrin[i] == 1)
                    287:                {
                    288:                        t2[i] = (t2[i] - 1) & mod_q_mask;
                    289:                }
                    290:                else if (Mtrin[i] == 2)
                    291:                {
                    292:                        t2[i] = (t2[i] + 1) & mod_q_mask;
                    293:                }
                    294:        }
                    295:        if (this->params->is_product_form)
                    296:        {
                    297:                t2[i] = (t2[i] + m1) & mod_q_mask;
                    298:        }
                    299: 
                    300:        /* allocate memory for the larger of the two seeds */
                    301:        seed1_len = (this->params->N + 3)/4;
                    302:        seed2_len = 3 + 2*this->params->sec_strength_len + this->params->m_len_max;
                    303:        seed = chunk_alloc(max(seed1_len, seed2_len));
                    304:        seed.len = seed1_len;
                    305: 
                    306:        /* form cR mod 4 */
                    307:        ntru_coeffs_mod4_2_octets(this->params->N, t2, seed.ptr);
                    308: 
                    309:        /* form mask */
                    310:        mask = ntru_trits_create(this->params->N, alg, seed);
                    311:        if (!mask)
                    312:        {
                    313:                DBG1(DBG_LIB, "mask creation failed");
                    314:                success = FALSE;
                    315:                goto err;
                    316:        }
                    317: 
                    318:        mask_trits = mask->get_trits(mask);
                    319: 
                    320:        /* form cMtrin by subtracting mask from cm', mod p */
                    321:        for (i = 0; i < cmprime_len; i++)
                    322:        {
                    323:                Mtrin[i] -=  mask_trits[i];
                    324:                if (Mtrin[i] >= 3)
                    325:                {
                    326:                        Mtrin[i] += 3;
                    327:                }
                    328:        }
                    329:        mask->destroy(mask);
                    330: 
                    331:        if (this->params->is_product_form)
                    332:        {
                    333:                /* set the last trit to zero since that's what it was, and
                    334:                 * because it can't be calculated from (cm' - mask) since
                    335:                 * we don't have the correct value for the last cm' trit
                    336:                 */
                    337:                Mtrin[i] = 0;
                    338:        }
                    339: 
                    340:        /* convert cMtrin to cM (Mtrin to Mbin) */
                    341:        if (!ntru_trits_2_bits(Mtrin, this->params->N, M))
                    342:        {
                    343:                success = FALSE;
                    344:                goto err;
                    345:        }
                    346: 
                    347:        /* skip the random padding */
                    348:        ptr = M + this->params->sec_strength_len;
                    349: 
                    350:        /* validate the padded message cM and copy cm to m_buf */
                    351:        if (this->params->m_len_len == 2)
                    352:        {
                    353:                cm_len = (uint16_t)(*ptr++) << 16;
                    354:        }
                    355:        cm_len |= (uint16_t)(*ptr++);
                    356: 
                    357:        if (cm_len > this->params->m_len_max)
                    358:        {
                    359:                cm_len = this->params->m_len_max;
                    360:                DBG1(DBG_LIB, "NTRU message length is larger than maximum length");
                    361:                success = FALSE;
                    362:        }
                    363:        cm = ptr;
                    364:        ptr += cm_len;
                    365: 
                    366:        /* check if the remaining padding consists of zeros */
                    367:        num_zeros = this->params->m_len_max - cm_len + 1;
                    368:        for (i = 0; i < num_zeros; i++)
                    369:        {
                    370:                if (ptr[i] != 0)
                    371:                {
                    372:                        DBG1(DBG_LIB, "non-zero trailing padding detected");
                    373:                        success = FALSE;
                    374:                        break;
                    375:                }
                    376:        }
                    377: 
                    378:        /* form sData (OID || m || b || hTrunc) */
                    379:        ptr = seed.ptr;
                    380:        memcpy(ptr, this->params->oid, 3);
                    381:        ptr += 3;
                    382:        memcpy(ptr, cm, cm_len);
                    383:        ptr += cm_len;
                    384:        memcpy(ptr, M, this->params->sec_strength_len);
                    385:        ptr += this->params->sec_strength_len;
                    386:        memcpy(ptr, this->encoding.ptr + 2 + NTRU_OID_LEN,
                    387:                   this->params->sec_strength_len);
                    388:        ptr += this->params->sec_strength_len;
                    389:        seed.len = ptr - seed.ptr;
                    390: 
                    391:        /* generate cr */
                    392:        DBG2(DBG_LIB, "generate polynomial r");
                    393:        r_poly = ntru_poly_create_from_seed(alg, seed, this->params->c_bits,
                    394:                                                this->params->N, this->params->q, this->params->dF_r,
                    395:                                                this->params->dF_r, this->params->is_product_form);
                    396:        if (!r_poly)
                    397:        {
                    398:                success = FALSE;
                    399:                goto err;
                    400:        }
                    401: 
                    402:        /* output plaintext in allocated chunk */
                    403:        *plaintext = chunk_clone(chunk_create(cm, cm_len));
                    404: 
                    405:        /* form cR' = h * cr */
                    406:        r_poly->ring_mult(r_poly, this->pubkey, t1);
                    407:        r_poly->destroy(r_poly);
                    408: 
                    409:        /* compare cR' to cR */
                    410:        for (i = 0; i < this->params->N; i++)
                    411:        {
                    412:                if (t[i] != t2[i])
                    413:                {
                    414:                        DBG1(DBG_LIB, "cR' does not equal cR'");
                    415:                        chunk_clear(plaintext);
                    416:                        success = FALSE;
                    417:                        break;
                    418:                }
                    419:        }
                    420:        memwipe(t, t_len);
                    421: 
                    422: err:
                    423:        /* cleanup */
                    424:        chunk_clear(&seed);
                    425:        free(t);
                    426: 
                    427:        return success;
                    428: }
                    429: 
                    430: METHOD(ntru_private_key_t, destroy, void,
                    431:        private_ntru_private_key_t *this)
                    432: {
                    433:        DESTROY_IF(this->privkey);
                    434:        this->drbg->destroy(this->drbg);
                    435:        chunk_clear(&this->encoding);
                    436:        free(this->pubkey);
                    437:        free(this);
                    438: }
                    439: 
                    440: /**
                    441:  * Multiplies ring element (polynomial) "a" by ring element (polynomial) "b"
                    442:  * to produce ring element (polynomial) "c" in (Z/qZ)[X]/(X^N - 1).
                    443:  * This is a convolution operation.
                    444:  *
                    445:  * Ring element "b" has coefficients in the range [0,N).
                    446:  *
                    447:  * This assumes q is 2^r where 8 < r < 16, so that overflow of the sum
                    448:  * beyond 16 bits does not matter.
                    449:  *
                    450:  * @param a            polynomial a
                    451:  * @param b            polynomial b
                    452:  * @param N            no. of coefficients in a, b, c
                    453:  * @param q            large modulus
                    454:  * @param c            polynomial c = a * b
                    455:  */
                    456: static void ring_mult_c(uint16_t *a, uint16_t *b, uint16_t N, uint16_t q,
                    457:                                            uint16_t *c)
                    458: {
                    459:        uint16_t *bptr = b;
                    460:        uint16_t mod_q_mask = q - 1;
                    461:        int i, k;
                    462: 
                    463:        /* c[k] = sum(a[i] * b[k-i]) mod q */
                    464:        memset(c, 0, N * sizeof(uint16_t));
                    465:        for (k = 0; k < N; k++)
                    466:        {
                    467:                i = 0;
                    468:                while (i <= k)
                    469:                {
                    470:                        c[k] += a[i++] * *bptr--;
                    471:                }
                    472:                bptr += N;
                    473:                while (i < N)
                    474:                {
                    475:                        c[k] += a[i++] * *bptr--;
                    476:                }
                    477:                c[k] &= mod_q_mask;
                    478:                ++bptr;
                    479:        }
                    480: }
                    481: 
                    482: /**
                    483:  * Finds the inverse of a polynomial a in (Z/2^rZ)[X]/(X^N - 1).
                    484:  *
                    485:  * This assumes q is 2^r where 8 < r < 16, so that operations mod q can
                    486:  * wait until the end, and only 16-bit arrays need to be used.
                    487:  *
                    488:  * @param a                    polynomial a
                    489:  * @param N                    no. of coefficients in a
                    490:  * @param q                    large modulus
                    491:  * @param t                    temporary buffer of size 2N elements
                    492:  * @param a_inv        polynomial for inverse of a
                    493:  */
                    494: static bool ring_inv(uint16_t *a, uint16_t N, uint16_t q, uint16_t *t,
                    495:                                         uint16_t *a_inv)
                    496: {
                    497:        uint8_t *b = (uint8_t *)t;
                    498:        uint8_t *c = b + N;
                    499:        uint8_t *f = c + N;
                    500:        uint8_t *g = (uint8_t *)a_inv;
                    501:        uint16_t *t2 = t + N;
                    502:        uint16_t deg_b, deg_c, deg_f, deg_g;
                    503:     bool done = FALSE;
                    504:     int i, j, k = 0;
                    505: 
                    506:        /* form a^-1 in (Z/2Z)[X]/X^N - 1) */
                    507:        memset(b, 0, 2 * N);                                    /* clear to init b, c */
                    508: 
                    509:        /* b(X) = 1 */
                    510:        b[0] = 1;
                    511:        deg_b = 0;
                    512: 
                    513:        /* c(X) = 0 (cleared above) */
                    514:        deg_c = 0;
                    515: 
                    516:        /* f(X) = a(X) mod 2 */
                    517:        for (i = 0; i < N; i++)
                    518:        {
                    519:                f[i] = (uint8_t)(a[i] & 1);
                    520:        }
                    521:        deg_f = N - 1;
                    522: 
                    523:        /* g(X) = X^N - 1 */
                    524:        g[0] = 1;
                    525:        memset(g + 1, 0, N - 1);
                    526:        g[N] = 1;
                    527:        deg_g = N;
                    528: 
                    529:        /* until f(X) = 1 */
                    530:        while (!done)
                    531:        {
                    532:                /* while f[0] = 0, f(X) /= X, c(X) *= X, k++ */
                    533:                for (i = 0; (i <= deg_f) && (f[i] == 0); ++i);
                    534: 
                    535:                if (i > deg_f)
                    536:                {
                    537:                        return FALSE;
                    538:                }
                    539:                if (i)
                    540:                {
                    541:                        f = f + i;
                    542:                        deg_f = deg_f - i;
                    543:                        deg_c = deg_c + i;
                    544:                        for (j = deg_c; j >= i; j--)
                    545:                        {
                    546:                                c[j] = c[j-i];
                    547:                        }
                    548:                        for (j = 0; j < i; j++)
                    549:                        {
                    550:                                c[j] = 0;
                    551:                        }
                    552:                        k = k + i;
                    553:                }
                    554: 
                    555:                /* adjust degree of f(X) if the highest coefficients are zero
                    556:                 * Note: f[0] = 1 from above so the loop will terminate.
                    557:                 */
                    558:                while (f[deg_f] == 0)
                    559:                {
                    560:                        --deg_f;
                    561:                }
                    562: 
                    563:                /* if f(X) = 1, done
                    564:                 * Note: f[0] = 1 from above, so only check the x term and up
                    565:                 */
                    566:                for (i = 1; (i <= deg_f) && (f[i] == 0); ++i);
                    567: 
                    568:                if (i > deg_f)
                    569:                {
                    570:                        done = TRUE;
                    571:                        break;
                    572:                }
                    573: 
                    574:                /* if deg_f < deg_g, f <-> g, b <-> c */
                    575:                if (deg_f < deg_g)
                    576:                {
                    577:                        uint8_t *x;
                    578: 
                    579:                        x = f;
                    580:                        f = g;
                    581:                        g = x;
                    582:                        deg_f ^= deg_g;
                    583:                        deg_g ^= deg_f;
                    584:                        deg_f ^= deg_g;
                    585:                        x = b;
                    586:                        b = c;
                    587:                        c = x;
                    588:                        deg_b ^= deg_c;
                    589:                        deg_c ^= deg_b;
                    590:                        deg_b ^= deg_c;
                    591:                }
                    592: 
                    593:                /* f(X) += g(X), b(X) += c(X) */
                    594:                for (i = 0; i <= deg_g; i++)
                    595:                {
                    596:                        f[i] ^= g[i];
                    597:                }
                    598:                if (deg_c > deg_b)
                    599:                {
                    600:                        deg_b = deg_c;
                    601:                }
                    602:                for (i = 0; i <= deg_c; i++)
                    603:                {
                    604:                        b[i] ^= c[i];
                    605:                }
                    606:        }
                    607: 
                    608:        /* a^-1 in (Z/2Z)[X]/(X^N - 1) = b(X) shifted left k coefficients */
                    609:        j = 0;
                    610:        if (k >= N)
                    611:        {
                    612:                k = k - N;
                    613:        }
                    614:        for (i = k; i < N; i++)
                    615:        {
                    616:                a_inv[j++] = (uint16_t)(b[i]);
                    617:        }
                    618:        for (i = 0; i < k; i++)
                    619:        {
                    620:                a_inv[j++] = (uint16_t)(b[i]);
                    621:        }
                    622: 
                    623:        /* lift a^-1 in (Z/2Z)[X]/(X^N - 1) to a^-1 in (Z/qZ)[X]/(X^N -1) */
                    624:     for (j = 0; j < 4; ++j)                            /* assumes 256 < q <= 65536 */
                    625:        {
                    626:                /* a^-1 = a^-1 * (2 - a * a^-1) mod q */
                    627:                memcpy(t2, a_inv, N * sizeof(uint16_t));
                    628:                ring_mult_c(a, t2, N, q, t);
                    629:                for (i = 0; i < N; ++i)
                    630:                {
                    631:                        t[i] = q - t[i];
                    632:                }
                    633:                t[0] = t[0] + 2;
                    634:                ring_mult_c(t2, t, N, q, a_inv);
                    635:        }
                    636: 
                    637:        return TRUE;
                    638: }
                    639: 
                    640: /*
                    641:  * Described in header.
                    642:  */
                    643: ntru_private_key_t *ntru_private_key_create(drbg_t *drbg,
                    644:                                                                                        const ntru_param_set_t *params)
                    645: {
                    646:        private_ntru_private_key_t *this;
                    647:        size_t t_len;
                    648:        uint16_t *t1, *t2, *t = NULL;
                    649:        uint16_t mod_q_mask;
                    650:     ext_out_function_t alg;
                    651:        ntru_poly_t *g_poly;
                    652:        chunk_t seed;
                    653:        int i;
                    654: 
                    655:        INIT(this,
                    656:                .public = {
                    657:                        .get_id = _get_id,
                    658:                        .get_public_key = _get_public_key,
                    659:                        .get_encoding = _get_encoding,
                    660:                        .decrypt = _decrypt,
                    661:                        .destroy = _destroy,
                    662:                },
                    663:                .params = params,
                    664:                .pubkey = malloc(params->N * sizeof(uint16_t)),
                    665:                .drbg = drbg->get_ref(drbg),
                    666:        );
                    667: 
                    668:        /* set hash algorithm and seed length based on security strength */
                    669:        alg = (params->sec_strength_len <= 20) ? XOF_MGF1_SHA1 :
                    670:                                                                                         XOF_MGF1_SHA256;
                    671:        seed =chunk_alloc(params->sec_strength_len + 8);
                    672: 
                    673:        /* get random seed for generating trinary F as a list of indices */
                    674:        if (!drbg->generate(drbg, seed.len, seed.ptr))
                    675:        {
                    676:                goto err;
                    677:        }
                    678: 
                    679:        DBG2(DBG_LIB, "generate polynomial F");
                    680:        this->privkey = ntru_poly_create_from_seed(alg, seed, params->c_bits,
                    681:                                                                                           params->N, params->q,
                    682:                                                                                           params->dF_r, params->dF_r,
                    683:                                                                                           params->is_product_form);
                    684:        if (!this->privkey)
                    685:        {
                    686:                goto err;
                    687:        }
                    688: 
                    689:        /* allocate temporary array t */
                    690:        t_len = 3 * params->N * sizeof(uint16_t);
                    691:        t = malloc(t_len);
                    692:        t1 = t + 2 * params->N;
                    693: 
                    694:        /* extend sparse private key polynomial f to N array elements */
                    695:        this->privkey->get_array(this->privkey, t1);
                    696: 
                    697:        /* set mask for large modulus */
                    698:        mod_q_mask = params->q - 1;
                    699: 
                    700:        /* form f = 1 + pF */
                    701:        for (i = 0; i < params->N; i++)
                    702:        {
                    703:                t1[i] = (t1[i] * 3) & mod_q_mask;
                    704:        }
                    705:        t1[0] = (t1[0] + 1) & mod_q_mask;
                    706: 
                    707:        /* use the public key array as a temporary buffer */
                    708:        t2 = this->pubkey;
                    709: 
                    710:        /* find f^-1 in (Z/qZ)[X]/(X^N - 1) */
                    711:        if (!ring_inv(t1, params->N, params->q, t, t2))
                    712:        {
                    713:                goto err;
                    714:        }
                    715: 
                    716:        /* get random seed for generating trinary g as a list of indices */
                    717:        if (!drbg->generate(drbg, seed.len, seed.ptr))
                    718:        {
                    719:                goto err;
                    720:        }
                    721: 
                    722:        DBG2(DBG_LIB, "generate polynomial g");
                    723:        g_poly = ntru_poly_create_from_seed(alg, seed, params->c_bits,
                    724:                                                                                params->N, params->q, params->dg + 1,
                    725:                                                                                params->dg, FALSE);
                    726:        if (!g_poly)
                    727:        {
                    728:                goto err;
                    729:        }
                    730: 
                    731:        /* compute public key polynomial h = p * (f^-1 * g) mod q */
                    732:        g_poly->ring_mult(g_poly, t2, t2);
                    733:        g_poly->destroy(g_poly);
                    734: 
                    735:        for (i = 0; i < params->N; i++)
                    736:        {
                    737:                this->pubkey[i] = (t2[i] * 3) & mod_q_mask;
                    738:        }
                    739: 
                    740:        /* cleanup temporary storage */
                    741:        chunk_clear(&seed);
                    742:        memwipe(t, t_len);
                    743:        free(t);
                    744: 
                    745:        /* generate private key encoding */
                    746:        generate_encoding(this);
                    747: 
                    748:        return &this->public;
                    749: 
                    750: err:
                    751:        chunk_free(&seed);
                    752:        free(t);
                    753:        destroy(this);
                    754: 
                    755:        return NULL;
                    756: }
                    757: 
                    758: /*
                    759:  * Described in header.
                    760:  */
                    761: ntru_private_key_t *ntru_private_key_create_from_data(drbg_t *drbg,
                    762:                                                                                                          chunk_t data)
                    763: {
                    764:        private_ntru_private_key_t *this;
                    765:        size_t header_len, pubkey_packed_len, privkey_packed_len;
                    766:        size_t privkey_packed_trits_len, privkey_packed_indices_len;
                    767:        uint8_t *privkey_packed, tag;
                    768:        uint16_t *indices, dF;
                    769:        const ntru_param_set_t *params;
                    770: 
                    771:        header_len = 2 + NTRU_OID_LEN;
                    772: 
                    773:        /* check the NTRU public key header format */
                    774:        if (data.len < header_len ||
                    775:                !(data.ptr[0] == NTRU_PRIVKEY_DEFAULT_TAG ||
                    776:                  data.ptr[0] == NTRU_PRIVKEY_TRITS_TAG ||
                    777:                  data.ptr[0] == NTRU_PRIVKEY_INDICES_TAG) ||
                    778:                data.ptr[1] != NTRU_OID_LEN)
                    779:        {
                    780:                DBG1(DBG_LIB, "loaded NTRU private key with invalid header");
                    781:                return NULL;
                    782:        }
                    783:        tag = data.ptr[0];
                    784:        params = ntru_param_set_get_by_oid(data.ptr + 2);
                    785: 
                    786:        if (!params)
                    787:        {
                    788:                DBG1(DBG_LIB, "loaded NTRU private key with unknown OID");
                    789:                return NULL;
                    790:        }
                    791: 
                    792:        pubkey_packed_len = (params->N * params->q_bits + 7) / 8;
                    793:        privkey_packed_trits_len = (params->N + 4) / 5;
                    794: 
                    795:        /* check packing type for product-form private keys */
                    796:        if (params->is_product_form &&  tag == NTRU_PRIVKEY_TRITS_TAG)
                    797:        {
                    798:                DBG1(DBG_LIB, "a product-form NTRU private key cannot be trits-encoded");
                    799:                return NULL;
                    800:        }
                    801: 
                    802:        /* set packed-key length for packed indices */
                    803:        if (params->is_product_form)
                    804:        {
                    805:                dF = (uint16_t)((params->dF_r & 0xff) +           /* df1 */
                    806:                                           ((params->dF_r >>  8) & 0xff) +    /* df2 */
                    807:                                           ((params->dF_r >> 16) & 0xff));    /* df3 */
                    808:        }
                    809:        else
                    810:        {
                    811:                dF = (uint16_t)params->dF_r;
                    812:        }
                    813:        privkey_packed_indices_len = (2 * dF * params->N_bits + 7) / 8;
                    814: 
                    815:        /* set private-key packing type if defaulted */
                    816:        if (tag == NTRU_PRIVKEY_DEFAULT_TAG)
                    817:        {
                    818:                if (params->is_product_form ||
                    819:             privkey_packed_indices_len <= privkey_packed_trits_len)
                    820:                {
                    821:                        tag = NTRU_PRIVKEY_INDICES_TAG;
                    822:                }
                    823:                else
                    824:                {
                    825:                        tag = NTRU_PRIVKEY_TRITS_TAG;
                    826:                }
                    827:        }
                    828:        privkey_packed_len = (tag == NTRU_PRIVKEY_TRITS_TAG) ?
                    829:                                 privkey_packed_trits_len : privkey_packed_indices_len;
                    830: 
                    831:        if (data.len < header_len + pubkey_packed_len + privkey_packed_len)
                    832:        {
                    833:                DBG1(DBG_LIB, "loaded NTRU private key with wrong packed key size");
                    834:                return NULL;
                    835:        }
                    836: 
                    837:        INIT(this,
                    838:                .public = {
                    839:                        .get_id = _get_id,
                    840:                        .get_public_key = _get_public_key,
                    841:                        .get_encoding = _get_encoding,
                    842:                        .decrypt = _decrypt,
                    843:                        .destroy = _destroy,
                    844:                },
                    845:                .params = params,
                    846:                .pubkey = malloc(params->N * sizeof(uint16_t)),
                    847:                .encoding = chunk_clone(data),
                    848:                .drbg = drbg->get_ref(drbg),
                    849:        );
                    850: 
                    851:        /* unpack the encoded public key */
                    852:        ntru_octets_2_elements(pubkey_packed_len, data.ptr + header_len,
                    853:                                                   params->q_bits, this->pubkey);
                    854: 
                    855:        /* allocate temporary memory for indices */
                    856:        indices = malloc(2 * dF * sizeof(uint16_t));
                    857: 
                    858:        /* unpack the private key */
                    859:        privkey_packed = data.ptr + header_len + pubkey_packed_len;
                    860:        if (tag == NTRU_PRIVKEY_TRITS_TAG)
                    861:        {
                    862:                ntru_packed_trits_2_indices(privkey_packed, params->N,
                    863:                                                                        indices, indices + dF);
                    864:     }
                    865:        else
                    866:        {
                    867:         ntru_octets_2_elements(privkey_packed_indices_len, privkey_packed,
                    868:                                                           params->N_bits, indices);
                    869:     }
                    870:        this->privkey = ntru_poly_create_from_data(indices, params->N, params->q,
                    871:                                                                                           params->dF_r, params->dF_r,
                    872:                                                                                           params->is_product_form);
                    873: 
                    874:        /* cleanup */
                    875:        memwipe(indices, 2 * dF * sizeof(uint16_t));
                    876:        free(indices);
                    877: 
                    878:        return &this->public;
                    879: }
                    880: 
                    881: EXPORT_FUNCTION_FOR_TESTS(ntru, ntru_private_key_create);
                    882: 
                    883: EXPORT_FUNCTION_FOR_TESTS(ntru, ntru_private_key_create_from_data);

FreeBSD-CVSweb <freebsd-cvsweb@FreeBSD.org>