File:  [ELWIX - Embedded LightWeight unIX -] / embedaddon / strongswan / src / libstrongswan / plugins / ntru / ntru_convert.c
Revision 1.1.1.1 (vendor branch): download - view: text, annotated - select for diffs - revision graph
Wed Jun 3 09:46:44 2020 UTC (4 years, 5 months ago) by misho
Branches: strongswan, MAIN
CVS tags: v5_9_2p0, v5_8_4p7, HEAD
Strongswan

/*
 * Copyright (C) 2014 Andreas Steffen
 * HSR Hochschule fuer Technik Rapperswil
 *
 * Copyright (C) 2009-2013  Security Innovation
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the
 * Free Software Foundation; either version 2 of the License, or (at your
 * option) any later version.  See <http://www.fsf.org/copyleft/gpl.txt>.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * for more details.
 */

#include <stdlib.h>
#include <string.h>

#include "ntru_convert.h"

/**
 * 3-bit to 2-trit conversion tables: 2 represents -1
 */
static uint8_t const bits_2_trit1[] = {0, 0, 0, 1, 1, 1, 2, 2};
static uint8_t const bits_2_trit2[] = {0, 1, 2, 0, 1, 2, 0, 1};

/**
 * See header.
 */
void ntru_bits_2_trits(uint8_t const *octets, uint16_t num_trits, uint8_t *trits)
{
	uint32_t bits24, bits3, shift;

	while (num_trits >= 16)
	{
		/* get next three octets */
		bits24  = ((uint32_t)(*octets++)) << 16;
		bits24 |= ((uint32_t)(*octets++)) <<  8;
		bits24 |=  (uint32_t)(*octets++);

		/* for each 3 bits in the three octets, output 2 trits */
		bits3 = (bits24 >> 21) & 0x7;
		*trits++ = bits_2_trit1[bits3];
		*trits++ = bits_2_trit2[bits3];

		bits3 = (bits24 >> 18) & 0x7;
		*trits++ = bits_2_trit1[bits3];
		*trits++ = bits_2_trit2[bits3];

		bits3 = (bits24 >> 15) & 0x7;
		*trits++ = bits_2_trit1[bits3];
		*trits++ = bits_2_trit2[bits3];

		bits3 = (bits24 >> 12) & 0x7;
		*trits++ = bits_2_trit1[bits3];
		*trits++ = bits_2_trit2[bits3];

		bits3 = (bits24 >>  9) & 0x7;
		*trits++ = bits_2_trit1[bits3];
		*trits++ = bits_2_trit2[bits3];

		bits3 = (bits24 >>  6) & 0x7;
		*trits++ = bits_2_trit1[bits3];
		*trits++ = bits_2_trit2[bits3];

		bits3 = (bits24 >>  3) & 0x7;
		*trits++ = bits_2_trit1[bits3];
		*trits++ = bits_2_trit2[bits3];

		bits3 = bits24 & 0x7;
		*trits++ = bits_2_trit1[bits3];
		*trits++ = bits_2_trit2[bits3];

		num_trits -= 16;
	}
	if (num_trits == 0)
	{
		return;
	}

	/* get three octets */
	bits24  = ((uint32_t)(*octets++)) << 16;
	bits24 |= ((uint32_t)(*octets++)) <<  8;
	bits24 |=  (uint32_t)(*octets++);

	shift = 21;
	while (num_trits)
	{
		/**
		 * for each 3 bits in the three octets, output up to 2 trits
		 * until all trits needed are produced
		 */
		bits3 = (bits24 >> shift) & 0x7;
		shift -= 3;
		*trits++ = bits_2_trit1[bits3];
		if (--num_trits)
		{
			*trits++ = bits_2_trit2[bits3];
			--num_trits;
		}
	}
}

/**
 * See header.
 */
bool ntru_trits_2_bits(uint8_t const *trits, uint32_t num_trits, uint8_t *octets)
{
	bool all_trits_valid = TRUE;
	uint32_t bits24, bits3, shift;

	while (num_trits >= 16)
	{
		/* convert each 2 trits to 3 bits and pack */
		bits3  = *trits++ * 3;
		bits3 += *trits++;
		if (bits3 > 7)
		{
			bits3 = 7;
			all_trits_valid = FALSE;
		}
		bits24 = (bits3 << 21);

		bits3  = *trits++ * 3;
		bits3 += *trits++;
		if (bits3 > 7)
		{
			bits3 = 7;
			all_trits_valid = FALSE;
		}
		bits24 |= (bits3 << 18);

		bits3  = *trits++ * 3;
		bits3 += *trits++;
		if (bits3 > 7)
		{
			bits3 = 7;
			all_trits_valid = FALSE;
		}
		bits24 |= (bits3 << 15);

		bits3  = *trits++ * 3;
		bits3 += *trits++;
		if (bits3 > 7)
		{
			bits3 = 7;
			all_trits_valid = FALSE;
		}
		bits24 |= (bits3 << 12);

		bits3  = *trits++ * 3;
		bits3 += *trits++;
		if (bits3 > 7)
		{
			bits3 = 7;
			all_trits_valid = FALSE;
		}
		bits24 |= (bits3 <<  9);

		bits3  = *trits++ * 3;
		bits3 += *trits++;
		if (bits3 > 7)
		{
			bits3 = 7;
			all_trits_valid = FALSE;
		}
		bits24 |= (bits3 <<  6);

		bits3  = *trits++ * 3;
		bits3 += *trits++;
		if (bits3 > 7)
		{
			bits3 = 7;
			all_trits_valid = FALSE;
		}
		bits24 |= (bits3 <<  3);

		bits3  = *trits++ * 3;
		bits3 += *trits++;
		if (bits3 > 7)
		{
			bits3 = 7;
			all_trits_valid = FALSE;
		}
		bits24 |= bits3;

		num_trits -= 16;

		/* output three octets */
		*octets++ = (uint8_t)((bits24 >> 16) & 0xff);
		*octets++ = (uint8_t)((bits24 >>  8) & 0xff);
		*octets++ = (uint8_t)(bits24 & 0xff);
	}

	bits24 = 0;
	shift = 21;
	while (num_trits)
	{
		/* convert each 2 trits to 3 bits and pack */
		bits3 = *trits++ * 3;
		if (--num_trits)
		{
			bits3 += *trits++;
			--num_trits;
		}
		if (bits3 > 7)
		{
			bits3 = 7;
			all_trits_valid = FALSE;
		}
		bits24 |= (bits3 << shift);
		shift -= 3;
	}

	/* output three octets */
	*octets++ = (uint8_t)((bits24 >> 16) & 0xff);
	*octets++ = (uint8_t)((bits24 >>  8) & 0xff);
	*octets++ = (uint8_t)(bits24 & 0xff);

	return all_trits_valid;
}

/**
 * See header
 */
void ntru_coeffs_mod4_2_octets(uint16_t num_coeffs, uint16_t const *coeffs, uint8_t *octets)
{
    uint8_t bits2;
    int shift, i;

	*octets = 0;
	shift = 6;
	for (i = 0; i < num_coeffs; i++)
	{
		bits2 = (uint8_t)(coeffs[i] & 0x3);
		*octets |= bits2 << shift;
		shift -= 2;
		if (shift < 0)
		{
			++octets;
			*octets = 0;
			shift = 6;
		}
	}
}

/**
 * See header.
 */
void ntru_trits_2_octet(uint8_t const *trits, uint8_t *octet)
{
	int i;

	*octet = 0;
	for (i = 4; i >= 0; i--)
	{
		*octet = (*octet * 3) + trits[i];
	}
}

/**
 * See header.
 */
void ntru_octet_2_trits(uint8_t octet, uint8_t *trits)
{
	int i;

	for (i = 0; i < 5; i++)
	{
		trits[i] = octet % 3;
		octet = (octet - trits[i]) / 3;
	}
}

/**
 * See header.
 */
void ntru_indices_2_trits(uint16_t in_len, uint16_t const *in, bool plus1,
						  uint8_t *out)
{
	uint8_t trit = plus1 ? 1 : 2;
	int  i;

    for (i = 0; i < in_len; i++)
	{
		out[in[i]] = trit;
	}
}

/**
 * See header.
 */
void ntru_packed_trits_2_indices(uint8_t const *in, uint16_t num_trits,
								 uint16_t *indices_plus1,
								 uint16_t *indices_minus1)
{
	uint8_t trits[5];
	uint16_t i = 0;
	int j;

	while (num_trits >= 5)
	{
		ntru_octet_2_trits(*in++, trits);
		num_trits -= 5;
		for (j = 0; j < 5; j++, i++)
		{
			if (trits[j] == 1)
			{
				*indices_plus1 = i;
				++indices_plus1;
			}
			else if (trits[j] == 2)
			{
				*indices_minus1 = i;
				++indices_minus1;
			}
		}
    }
	if (num_trits)
	{
		ntru_octet_2_trits(*in, trits);
		for (j = 0; num_trits && (j < 5); j++, i++)
		{
			if (trits[j] == 1)
			{
				*indices_plus1 = i;
				++indices_plus1;
			}
			else if (trits[j] == 2)
			{
				*indices_minus1 = i;
				++indices_minus1;
			}
			--num_trits;
		}
	}
}

/**
 * See header.
 */
void ntru_indices_2_packed_trits(uint16_t const *indices, uint16_t num_plus1,
								 uint16_t num_minus1, uint16_t num_trits,
								 uint8_t *buf, uint8_t *out)
{
	/* convert indices to an array of trits */
	memset(buf, 0, num_trits);
	ntru_indices_2_trits(num_plus1, indices, TRUE, buf);
	ntru_indices_2_trits(num_minus1, indices + num_plus1, FALSE, buf);

	/* pack the array of trits */
	while (num_trits >= 5)
	{
		ntru_trits_2_octet(buf, out);
		num_trits -= 5;
		buf += 5;
		++out;
	}
	if (num_trits)
	{
		uint8_t trits[5];

		memcpy(trits, buf, num_trits);
		memset(trits + num_trits, 0, sizeof(trits) - num_trits);
		ntru_trits_2_octet(trits, out);
	}
}

/**
 * See header
 */
void ntru_elements_2_octets(uint16_t in_len, uint16_t const *in, uint8_t n_bits,
							uint8_t *out)
{
	uint16_t temp;
	int shift, i;

	/* pack */
	temp = 0;
	shift = n_bits - 8;
	i = 0;
	while (i < in_len)
	{
		/* add bits to temp to fill an octet and output the octet */
		temp |= in[i] >> shift;
		*out++ = (uint8_t)(temp & 0xff);
		shift = 8 - shift;
		if (shift < 1)
		{
			/* next full octet is in current input word */
			shift += n_bits;
			temp = 0;
		}
		else
		{
			/* put remaining bits of input word in temp as partial octet,
			 * and increment index to next input word
			 */
			temp = in[i] << (uint16_t)shift;
			++i;
		}
		shift = n_bits - shift;
	}

	/* output any bits remaining in last input word */
	if (shift != n_bits - 8)
	{
		*out++ = (uint8_t)(temp & 0xff);
	}
}


/**
 * See header.
 */
void ntru_octets_2_elements(uint16_t in_len, uint8_t const *in, uint8_t n_bits,
							uint16_t *out)
{
	uint16_t  temp;
	uint16_t  mask = (1 << n_bits) - 1;
	int shift, i;

	/* unpack */
	temp = 0;
	shift = n_bits;
	i = 0;
	while (i < in_len)
	{
		shift = 8 - shift;
		if (shift < 0)
		{
			/* the current octet will not fill the current element */
			shift += n_bits;
		}
		else
		{
			/* add bits from the current octet to fill the current element and
			 * output the element
			 */
			temp |= ((uint16_t)in[i]) >> shift;
			*out++ = temp & mask;
			temp = 0;
		}

		/* add the remaining bits of the current octet to start an element */
		shift = n_bits - shift;
		temp |= ((uint16_t)in[i]) << shift;
		++i;
	}
}

FreeBSD-CVSweb <freebsd-cvsweb@FreeBSD.org>