/* Source: ELWIX (Embedded LightWeight unIX) CVS, embedaddon/strongswan/src/libstrongswan/plugins/ntru/ntru_convert.c, revision 1.1 (misho). */
1.1 misho 1: /* 2: * Copyright (C) 2014 Andreas Steffen 3: * HSR Hochschule fuer Technik Rapperswil 4: * 5: * Copyright (C) 2009-2013 Security Innovation 6: * 7: * This program is free software; you can redistribute it and/or modify it 8: * under the terms of the GNU General Public License as published by the 9: * Free Software Foundation; either version 2 of the License, or (at your 10: * option) any later version. See <http://www.fsf.org/copyleft/gpl.txt>. 11: * 12: * This program is distributed in the hope that it will be useful, but 13: * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY 14: * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License 15: * for more details. 16: */ 17: 18: #include <stdlib.h> 19: #include <string.h> 20: 21: #include "ntru_convert.h" 22: 23: /** 24: * 3-bit to 2-trit conversion tables: 2 represents -1 25: */ 26: static uint8_t const bits_2_trit1[] = {0, 0, 0, 1, 1, 1, 2, 2}; 27: static uint8_t const bits_2_trit2[] = {0, 1, 2, 0, 1, 2, 0, 1}; 28: 29: /** 30: * See header. 
31: */ 32: void ntru_bits_2_trits(uint8_t const *octets, uint16_t num_trits, uint8_t *trits) 33: { 34: uint32_t bits24, bits3, shift; 35: 36: while (num_trits >= 16) 37: { 38: /* get next three octets */ 39: bits24 = ((uint32_t)(*octets++)) << 16; 40: bits24 |= ((uint32_t)(*octets++)) << 8; 41: bits24 |= (uint32_t)(*octets++); 42: 43: /* for each 3 bits in the three octets, output 2 trits */ 44: bits3 = (bits24 >> 21) & 0x7; 45: *trits++ = bits_2_trit1[bits3]; 46: *trits++ = bits_2_trit2[bits3]; 47: 48: bits3 = (bits24 >> 18) & 0x7; 49: *trits++ = bits_2_trit1[bits3]; 50: *trits++ = bits_2_trit2[bits3]; 51: 52: bits3 = (bits24 >> 15) & 0x7; 53: *trits++ = bits_2_trit1[bits3]; 54: *trits++ = bits_2_trit2[bits3]; 55: 56: bits3 = (bits24 >> 12) & 0x7; 57: *trits++ = bits_2_trit1[bits3]; 58: *trits++ = bits_2_trit2[bits3]; 59: 60: bits3 = (bits24 >> 9) & 0x7; 61: *trits++ = bits_2_trit1[bits3]; 62: *trits++ = bits_2_trit2[bits3]; 63: 64: bits3 = (bits24 >> 6) & 0x7; 65: *trits++ = bits_2_trit1[bits3]; 66: *trits++ = bits_2_trit2[bits3]; 67: 68: bits3 = (bits24 >> 3) & 0x7; 69: *trits++ = bits_2_trit1[bits3]; 70: *trits++ = bits_2_trit2[bits3]; 71: 72: bits3 = bits24 & 0x7; 73: *trits++ = bits_2_trit1[bits3]; 74: *trits++ = bits_2_trit2[bits3]; 75: 76: num_trits -= 16; 77: } 78: if (num_trits == 0) 79: { 80: return; 81: } 82: 83: /* get three octets */ 84: bits24 = ((uint32_t)(*octets++)) << 16; 85: bits24 |= ((uint32_t)(*octets++)) << 8; 86: bits24 |= (uint32_t)(*octets++); 87: 88: shift = 21; 89: while (num_trits) 90: { 91: /** 92: * for each 3 bits in the three octets, output up to 2 trits 93: * until all trits needed are produced 94: */ 95: bits3 = (bits24 >> shift) & 0x7; 96: shift -= 3; 97: *trits++ = bits_2_trit1[bits3]; 98: if (--num_trits) 99: { 100: *trits++ = bits_2_trit2[bits3]; 101: --num_trits; 102: } 103: } 104: } 105: 106: /** 107: * See header. 
108: */ 109: bool ntru_trits_2_bits(uint8_t const *trits, uint32_t num_trits, uint8_t *octets) 110: { 111: bool all_trits_valid = TRUE; 112: uint32_t bits24, bits3, shift; 113: 114: while (num_trits >= 16) 115: { 116: /* convert each 2 trits to 3 bits and pack */ 117: bits3 = *trits++ * 3; 118: bits3 += *trits++; 119: if (bits3 > 7) 120: { 121: bits3 = 7; 122: all_trits_valid = FALSE; 123: } 124: bits24 = (bits3 << 21); 125: 126: bits3 = *trits++ * 3; 127: bits3 += *trits++; 128: if (bits3 > 7) 129: { 130: bits3 = 7; 131: all_trits_valid = FALSE; 132: } 133: bits24 |= (bits3 << 18); 134: 135: bits3 = *trits++ * 3; 136: bits3 += *trits++; 137: if (bits3 > 7) 138: { 139: bits3 = 7; 140: all_trits_valid = FALSE; 141: } 142: bits24 |= (bits3 << 15); 143: 144: bits3 = *trits++ * 3; 145: bits3 += *trits++; 146: if (bits3 > 7) 147: { 148: bits3 = 7; 149: all_trits_valid = FALSE; 150: } 151: bits24 |= (bits3 << 12); 152: 153: bits3 = *trits++ * 3; 154: bits3 += *trits++; 155: if (bits3 > 7) 156: { 157: bits3 = 7; 158: all_trits_valid = FALSE; 159: } 160: bits24 |= (bits3 << 9); 161: 162: bits3 = *trits++ * 3; 163: bits3 += *trits++; 164: if (bits3 > 7) 165: { 166: bits3 = 7; 167: all_trits_valid = FALSE; 168: } 169: bits24 |= (bits3 << 6); 170: 171: bits3 = *trits++ * 3; 172: bits3 += *trits++; 173: if (bits3 > 7) 174: { 175: bits3 = 7; 176: all_trits_valid = FALSE; 177: } 178: bits24 |= (bits3 << 3); 179: 180: bits3 = *trits++ * 3; 181: bits3 += *trits++; 182: if (bits3 > 7) 183: { 184: bits3 = 7; 185: all_trits_valid = FALSE; 186: } 187: bits24 |= bits3; 188: 189: num_trits -= 16; 190: 191: /* output three octets */ 192: *octets++ = (uint8_t)((bits24 >> 16) & 0xff); 193: *octets++ = (uint8_t)((bits24 >> 8) & 0xff); 194: *octets++ = (uint8_t)(bits24 & 0xff); 195: } 196: 197: bits24 = 0; 198: shift = 21; 199: while (num_trits) 200: { 201: /* convert each 2 trits to 3 bits and pack */ 202: bits3 = *trits++ * 3; 203: if (--num_trits) 204: { 205: bits3 += *trits++; 206: 
--num_trits; 207: } 208: if (bits3 > 7) 209: { 210: bits3 = 7; 211: all_trits_valid = FALSE; 212: } 213: bits24 |= (bits3 << shift); 214: shift -= 3; 215: } 216: 217: /* output three octets */ 218: *octets++ = (uint8_t)((bits24 >> 16) & 0xff); 219: *octets++ = (uint8_t)((bits24 >> 8) & 0xff); 220: *octets++ = (uint8_t)(bits24 & 0xff); 221: 222: return all_trits_valid; 223: } 224: 225: /** 226: * See header 227: */ 228: void ntru_coeffs_mod4_2_octets(uint16_t num_coeffs, uint16_t const *coeffs, uint8_t *octets) 229: { 230: uint8_t bits2; 231: int shift, i; 232: 233: *octets = 0; 234: shift = 6; 235: for (i = 0; i < num_coeffs; i++) 236: { 237: bits2 = (uint8_t)(coeffs[i] & 0x3); 238: *octets |= bits2 << shift; 239: shift -= 2; 240: if (shift < 0) 241: { 242: ++octets; 243: *octets = 0; 244: shift = 6; 245: } 246: } 247: } 248: 249: /** 250: * See header. 251: */ 252: void ntru_trits_2_octet(uint8_t const *trits, uint8_t *octet) 253: { 254: int i; 255: 256: *octet = 0; 257: for (i = 4; i >= 0; i--) 258: { 259: *octet = (*octet * 3) + trits[i]; 260: } 261: } 262: 263: /** 264: * See header. 265: */ 266: void ntru_octet_2_trits(uint8_t octet, uint8_t *trits) 267: { 268: int i; 269: 270: for (i = 0; i < 5; i++) 271: { 272: trits[i] = octet % 3; 273: octet = (octet - trits[i]) / 3; 274: } 275: } 276: 277: /** 278: * See header. 279: */ 280: void ntru_indices_2_trits(uint16_t in_len, uint16_t const *in, bool plus1, 281: uint8_t *out) 282: { 283: uint8_t trit = plus1 ? 1 : 2; 284: int i; 285: 286: for (i = 0; i < in_len; i++) 287: { 288: out[in[i]] = trit; 289: } 290: } 291: 292: /** 293: * See header. 
294: */ 295: void ntru_packed_trits_2_indices(uint8_t const *in, uint16_t num_trits, 296: uint16_t *indices_plus1, 297: uint16_t *indices_minus1) 298: { 299: uint8_t trits[5]; 300: uint16_t i = 0; 301: int j; 302: 303: while (num_trits >= 5) 304: { 305: ntru_octet_2_trits(*in++, trits); 306: num_trits -= 5; 307: for (j = 0; j < 5; j++, i++) 308: { 309: if (trits[j] == 1) 310: { 311: *indices_plus1 = i; 312: ++indices_plus1; 313: } 314: else if (trits[j] == 2) 315: { 316: *indices_minus1 = i; 317: ++indices_minus1; 318: } 319: } 320: } 321: if (num_trits) 322: { 323: ntru_octet_2_trits(*in, trits); 324: for (j = 0; num_trits && (j < 5); j++, i++) 325: { 326: if (trits[j] == 1) 327: { 328: *indices_plus1 = i; 329: ++indices_plus1; 330: } 331: else if (trits[j] == 2) 332: { 333: *indices_minus1 = i; 334: ++indices_minus1; 335: } 336: --num_trits; 337: } 338: } 339: } 340: 341: /** 342: * See header. 343: */ 344: void ntru_indices_2_packed_trits(uint16_t const *indices, uint16_t num_plus1, 345: uint16_t num_minus1, uint16_t num_trits, 346: uint8_t *buf, uint8_t *out) 347: { 348: /* convert indices to an array of trits */ 349: memset(buf, 0, num_trits); 350: ntru_indices_2_trits(num_plus1, indices, TRUE, buf); 351: ntru_indices_2_trits(num_minus1, indices + num_plus1, FALSE, buf); 352: 353: /* pack the array of trits */ 354: while (num_trits >= 5) 355: { 356: ntru_trits_2_octet(buf, out); 357: num_trits -= 5; 358: buf += 5; 359: ++out; 360: } 361: if (num_trits) 362: { 363: uint8_t trits[5]; 364: 365: memcpy(trits, buf, num_trits); 366: memset(trits + num_trits, 0, sizeof(trits) - num_trits); 367: ntru_trits_2_octet(trits, out); 368: } 369: } 370: 371: /** 372: * See header 373: */ 374: void ntru_elements_2_octets(uint16_t in_len, uint16_t const *in, uint8_t n_bits, 375: uint8_t *out) 376: { 377: uint16_t temp; 378: int shift, i; 379: 380: /* pack */ 381: temp = 0; 382: shift = n_bits - 8; 383: i = 0; 384: while (i < in_len) 385: { 386: /* add bits to temp to fill an 
octet and output the octet */ 387: temp |= in[i] >> shift; 388: *out++ = (uint8_t)(temp & 0xff); 389: shift = 8 - shift; 390: if (shift < 1) 391: { 392: /* next full octet is in current input word */ 393: shift += n_bits; 394: temp = 0; 395: } 396: else 397: { 398: /* put remaining bits of input word in temp as partial octet, 399: * and increment index to next input word 400: */ 401: temp = in[i] << (uint16_t)shift; 402: ++i; 403: } 404: shift = n_bits - shift; 405: } 406: 407: /* output any bits remaining in last input word */ 408: if (shift != n_bits - 8) 409: { 410: *out++ = (uint8_t)(temp & 0xff); 411: } 412: } 413: 414: 415: /** 416: * See header. 417: */ 418: void ntru_octets_2_elements(uint16_t in_len, uint8_t const *in, uint8_t n_bits, 419: uint16_t *out) 420: { 421: uint16_t temp; 422: uint16_t mask = (1 << n_bits) - 1; 423: int shift, i; 424: 425: /* unpack */ 426: temp = 0; 427: shift = n_bits; 428: i = 0; 429: while (i < in_len) 430: { 431: shift = 8 - shift; 432: if (shift < 0) 433: { 434: /* the current octet will not fill the current element */ 435: shift += n_bits; 436: } 437: else 438: { 439: /* add bits from the current octet to fill the current element and 440: * output the element 441: */ 442: temp |= ((uint16_t)in[i]) >> shift; 443: *out++ = temp & mask; 444: temp = 0; 445: } 446: 447: /* add the remaining bits of the current octet to start an element */ 448: shift = n_bits - shift; 449: temp |= ((uint16_t)in[i]) << shift; 450: ++i; 451: } 452: }