Annotation of embedaddon/strongswan/src/libstrongswan/plugins/ntru/ntru_convert.c, revision 1.1.1.1
1.1 misho 1: /*
2: * Copyright (C) 2014 Andreas Steffen
3: * HSR Hochschule fuer Technik Rapperswil
4: *
5: * Copyright (C) 2009-2013 Security Innovation
6: *
7: * This program is free software; you can redistribute it and/or modify it
8: * under the terms of the GNU General Public License as published by the
9: * Free Software Foundation; either version 2 of the License, or (at your
10: * option) any later version. See <http://www.fsf.org/copyleft/gpl.txt>.
11: *
12: * This program is distributed in the hope that it will be useful, but
13: * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
14: * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
15: * for more details.
16: */
17:
18: #include <stdlib.h>
19: #include <string.h>
20:
21: #include "ntru_convert.h"
22:
/**
 * Lookup tables mapping a 3-bit value (0..7) to a pair of trits:
 * bits_2_trit1 yields the first trit of the pair, bits_2_trit2 the second.
 * A trit value of 2 represents -1.
 */
static uint8_t const bits_2_trit1[] = {0, 0, 0, 1, 1, 1, 2, 2};
static uint8_t const bits_2_trit2[] = {0, 1, 2, 0, 1, 2, 0, 1};

/**
 * Expand packed bits into trits, two trits per 3-bit group. See header.
 *
 * The input is consumed three octets (eight 3-bit groups, i.e. 16 trits) at
 * a time, most significant group first. If num_trits is not a multiple of 16,
 * one further triple of octets is read in full and only as many trits as are
 * still required are emitted from it.
 *
 * @param octets		packed input bits
 * @param num_trits		number of trits to produce
 * @param trits			output array receiving num_trits entries
 */
void ntru_bits_2_trits(uint8_t const *octets, uint16_t num_trits, uint8_t *trits)
{
	uint32_t word, group;
	int pos;

	/* complete 24-bit words: each one yields exactly 16 trits */
	for (; num_trits >= 16; num_trits -= 16)
	{
		word = ((uint32_t)octets[0] << 16) |
			   ((uint32_t)octets[1] << 8) |
			    (uint32_t)octets[2];
		octets += 3;

		for (pos = 21; pos >= 0; pos -= 3)
		{
			group = (word >> pos) & 0x7;
			*trits++ = bits_2_trit1[group];
			*trits++ = bits_2_trit2[group];
		}
	}

	/* nothing left over: do not touch any further input */
	if (!num_trits)
	{
		return;
	}

	/* partial word: all three octets are read, fewer than 16 trits written */
	word = ((uint32_t)octets[0] << 16) |
		   ((uint32_t)octets[1] << 8) |
		    (uint32_t)octets[2];

	for (pos = 21; num_trits > 0; pos -= 3)
	{
		group = (word >> pos) & 0x7;

		*trits++ = bits_2_trit1[group];
		if (--num_trits == 0)
		{
			/* odd trit count: second trit of this group is dropped */
			break;
		}
		*trits++ = bits_2_trit2[group];
		--num_trits;
	}
}
105:
/**
 * Pack trits into bits, three bits per pair of trits. See header.
 *
 * Each pair (t1, t2) is encoded as t1 * 3 + t2. An encoded value above 7
 * does not fit into three bits; it is replaced by 7 and the function result
 * becomes FALSE. Eight pairs form a 24-bit word that is written as three
 * octets, most significant octet first.
 *
 * After all complete 16-trit groups have been written, one more triple of
 * octets is always written: it holds the remaining (fewer than 16) trits,
 * padded with zero bits, and is all zero if no trits remain. If the number of
 * remaining trits is odd, the last trit is encoded as t1 * 3 on its own.
 *
 * @param trits			input trits
 * @param num_trits		number of input trits
 * @param octets		output buffer for the packed bits
 * @return				TRUE if no encoded pair exceeded 7, FALSE otherwise
 */
bool ntru_trits_2_bits(uint8_t const *trits, uint32_t num_trits, uint8_t *octets)
{
	bool ok = TRUE;
	uint32_t word, pair;
	int pos;

	/* complete groups of 16 trits -> three octets each */
	for (; num_trits >= 16; num_trits -= 16)
	{
		word = 0;
		for (pos = 21; pos >= 0; pos -= 3)
		{
			pair = trits[0] * 3;
			pair += trits[1];
			trits += 2;
			if (pair > 7)
			{
				pair = 7;
				ok = FALSE;
			}
			word |= pair << pos;
		}

		*octets++ = (uint8_t)((word >> 16) & 0xff);
		*octets++ = (uint8_t)((word >> 8) & 0xff);
		*octets++ = (uint8_t)(word & 0xff);
	}

	/* remaining trits (possibly none) go into one zero-padded word */
	word = 0;
	for (pos = 21; num_trits > 0; pos -= 3)
	{
		pair = *trits++ * 3;
		--num_trits;
		if (num_trits > 0)
		{
			pair += *trits++;
			--num_trits;
		}
		if (pair > 7)
		{
			pair = 7;
			ok = FALSE;
		}
		word |= pair << pos;
	}

	/* the trailing word is emitted unconditionally */
	*octets++ = (uint8_t)((word >> 16) & 0xff);
	*octets++ = (uint8_t)((word >> 8) & 0xff);
	*octets++ = (uint8_t)(word & 0xff);

	return ok;
}
224:
/**
 * Pack the two least significant bits of each coefficient into octets,
 * four coefficients per octet, first coefficient in the two most
 * significant bits. See header.
 *
 * Exactly (num_coeffs + 3) / 4 octets are written. A partially filled last
 * octet is padded with zero bits. No octet beyond that is touched; in
 * particular nothing is written when num_coeffs is 0, and no extra zero
 * octet is appended when num_coeffs is a multiple of 4.
 *
 * @param num_coeffs	number of coefficients
 * @param coeffs		input coefficients, only bits 0..1 of each are used
 * @param octets		output buffer of at least (num_coeffs + 3) / 4 octets
 */
void ntru_coeffs_mod4_2_octets(uint16_t num_coeffs, uint16_t const *coeffs, uint8_t *octets)
{
	uint8_t bits2;
	int shift, i;

	shift = 6;
	for (i = 0; i < num_coeffs; i++)
	{
		if (shift == 6)
		{
			/* first coefficient of a new octet: clear it lazily so that we
			 * never write an octet that receives no coefficient at all
			 */
			*octets = 0;
		}
		bits2 = (uint8_t)(coeffs[i] & 0x3);
		*octets |= bits2 << shift;
		shift -= 2;
		if (shift < 0)
		{
			/* octet is full, continue with the next one */
			++octets;
			shift = 6;
		}
	}
}
248:
/**
 * Pack five trits into one octet. See header.
 *
 * The result is trits[0] + 3 * trits[1] + 9 * trits[2] + 27 * trits[3] +
 * 81 * trits[4], reduced to 8 bits, evaluated from trits[4] down to trits[0].
 *
 * @param trits			array of five trits, trits[0] least significant
 * @param octet			receives the packed value
 */
void ntru_trits_2_octet(uint8_t const *trits, uint8_t *octet)
{
	uint8_t const *digit = trits + 5;

	*octet = 0;
	while (digit != trits)
	{
		--digit;
		*octet = (*octet * 3) + *digit;
	}
}
262:
/**
 * Unpack one octet into five base-3 digits. See header.
 *
 * trits[0] receives the least significant digit. Whatever remains of the
 * octet after five digits have been extracted is discarded.
 *
 * @param octet			packed value
 * @param trits			output array of five entries, each in 0..2
 */
void ntru_octet_2_trits(uint8_t octet, uint8_t *trits)
{
	uint8_t *end = trits + 5;

	while (trits != end)
	{
		*trits++ = octet % 3;
		octet /= 3;
	}
}
276:
/**
 * Mark the listed positions of a trit array. See header.
 *
 * For every index in the input list, out[index] is set to 1 if plus1 is
 * true and to 2 otherwise. Entries of out that are not listed are left
 * unchanged.
 *
 * @param in_len		number of indices
 * @param in			indices into out
 * @param plus1			selects the trit value to store: 1 if true, else 2
 * @param out			trit array to be updated
 */
void ntru_indices_2_trits(uint16_t in_len, uint16_t const *in, bool plus1,
						  uint8_t *out)
{
	uint8_t mark;

	if (plus1)
	{
		mark = 1;
	}
	else
	{
		mark = 2;
	}

	while (in_len--)
	{
		out[*in++] = mark;
	}
}
291:
/**
 * Unpack octets holding five trits each and record the positions of the
 * non-zero trits. See header.
 *
 * Trit positions are counted from 0 across the whole input. The position of
 * every trit equal to 1 is appended to indices_plus1, the position of every
 * trit equal to 2 is appended to indices_minus1; trits equal to 0 are
 * skipped. Of the last octet only as many trits as are needed to reach
 * num_trits are inspected.
 *
 * @param in				packed trits, five per octet
 * @param num_trits			total number of trits to unpack
 * @param indices_plus1		receives the positions of the +1 trits
 * @param indices_minus1	receives the positions of the -1 trits
 */
void ntru_packed_trits_2_indices(uint8_t const *in, uint16_t num_trits,
								 uint16_t *indices_plus1,
								 uint16_t *indices_minus1)
{
	uint8_t digits[5];
	uint16_t position = 0;
	uint16_t chunk;
	int k;

	while (num_trits > 0)
	{
		/* a full octet carries five trits, the last one possibly fewer */
		chunk = (num_trits < 5) ? num_trits : 5;

		ntru_octet_2_trits(*in++, digits);
		num_trits -= chunk;

		for (k = 0; k < chunk; k++, position++)
		{
			switch (digits[k])
			{
				case 1:
					*indices_plus1++ = position;
					break;
				case 2:
					*indices_minus1++ = position;
					break;
				default:
					break;
			}
		}
	}
}
340:
341: /**
342: * See header.
343: */
344: void ntru_indices_2_packed_trits(uint16_t const *indices, uint16_t num_plus1,
345: uint16_t num_minus1, uint16_t num_trits,
346: uint8_t *buf, uint8_t *out)
347: {
348: /* convert indices to an array of trits */
349: memset(buf, 0, num_trits);
350: ntru_indices_2_trits(num_plus1, indices, TRUE, buf);
351: ntru_indices_2_trits(num_minus1, indices + num_plus1, FALSE, buf);
352:
353: /* pack the array of trits */
354: while (num_trits >= 5)
355: {
356: ntru_trits_2_octet(buf, out);
357: num_trits -= 5;
358: buf += 5;
359: ++out;
360: }
361: if (num_trits)
362: {
363: uint8_t trits[5];
364:
365: memcpy(trits, buf, num_trits);
366: memset(trits + num_trits, 0, sizeof(trits) - num_trits);
367: ntru_trits_2_octet(trits, out);
368: }
369: }
370:
/**
 * Pack an array of n_bits wide elements into a contiguous bit string,
 * most significant bit first. See header.
 *
 * Elements are not masked to n_bits before packing. The initial right shift
 * is n_bits - 8, so n_bits must be at least 8 for the shifts to be valid.
 * If the packed bits do not end on an octet boundary, the last octet is
 * padded with zero bits.
 *
 * @param in_len		number of elements
 * @param in			input elements
 * @param n_bits		number of bits taken from each element
 * @param out			output octets
 */
void ntru_elements_2_octets(uint16_t in_len, uint16_t const *in, uint8_t n_bits,
							uint8_t *out)
{
	uint16_t acc = 0;
	/* right shift that aligns the next bits of in[idx] to fill the octet */
	int rshift = n_bits - 8;
	int idx = 0;

	while (idx < in_len)
	{
		/* complete the pending octet with bits of the current element */
		acc |= in[idx] >> rshift;
		*out++ = (uint8_t)(acc & 0xff);

		if (rshift >= 8)
		{
			/* at least eight more bits remain in this element: the next
			 * octet is taken entirely from it
			 */
			acc = 0;
			rshift -= 8;
		}
		else
		{
			/* fewer than eight bits remain: they become the upper part of
			 * the next octet, continue with the following element
			 */
			acc = in[idx] << (8 - rshift);
			++idx;
			rshift += n_bits - 8;
		}
	}

	/* flush a partially filled last octet */
	if (rshift != n_bits - 8)
	{
		*out++ = (uint8_t)(acc & 0xff);
	}
}
413:
414:
/**
 * Unpack a contiguous bit string into n_bits wide elements, most
 * significant bit first. See header.
 *
 * An element is written as soon as the current octet supplies its last
 * missing bits; each written element is masked to n_bits. Bits left over
 * at the end of the input that do not complete an element are dropped.
 *
 * @param in_len		number of input octets
 * @param in			input octets
 * @param n_bits		number of bits per element
 * @param out			output elements
 */
void ntru_octets_2_elements(uint16_t in_len, uint8_t const *in, uint8_t n_bits,
							uint16_t *out)
{
	uint16_t const mask = (1 << n_bits) - 1;
	uint16_t acc = 0;
	uint16_t octet;
	/* number of bits still missing from the element being assembled */
	int missing = n_bits;
	int idx;

	for (idx = 0; idx < in_len; idx++)
	{
		octet = in[idx];

		if (missing > 8)
		{
			/* the whole octet fits into the current element */
			missing -= 8;
		}
		else
		{
			/* the top bits of the octet complete the current element */
			acc |= octet >> (8 - missing);
			*out++ = acc & mask;
			acc = 0;

			/* the rest of the octet already belongs to the next element */
			missing += n_bits - 8;
		}

		/* place the octet's bits above the still missing ones; anything
		 * shifted beyond n_bits is removed by the mask on output
		 */
		acc |= octet << missing;
	}
}
FreeBSD-CVSweb <freebsd-cvsweb@FreeBSD.org>