| /* |
| * Copyright 2017 Google LLC. |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * https://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "ntt_parameters.h" |
| |
| #include <cstdint> |
| #include <vector> |
| |
| #include <gmock/gmock.h> |
| #include <gtest/gtest.h> |
| #include "absl/numeric/int128.h" |
| #include "constants.h" |
| #include "montgomery.h" |
| #include "status_macros.h" |
| #include "testing/parameters.h" |
| #include "testing/status_matchers.h" |
| #include "testing/status_testing.h" |
| |
| namespace { |
| |
| using ::rlwe::testing::StatusIs; |
| using ::testing::HasSubstr; |
| |
| template <typename ModularInt> |
| class NttParametersTest : public testing::Test {}; |
| TYPED_TEST_SUITE(NttParametersTest, rlwe::testing::ModularIntTypes); |
| |
| TYPED_TEST(NttParametersTest, LogNumCoeffsTooLarge) { |
| for (const auto& params : |
| rlwe::testing::ContextParameters<TypeParam>::Value()) { |
| // Do not create a context, since it creates NttParameters already. Instead, |
| // create the modulus parameters manually. |
| ASSERT_OK_AND_ASSIGN(auto modulus_params, |
| TypeParam::Params::Create(params.modulus)); |
| |
| int log_n = rlwe::kMaxLogNumCoeffs + 1; |
| EXPECT_THAT( |
| rlwe::InitializeNttParameters<TypeParam>(log_n, modulus_params.get()), |
| StatusIs( |
| ::absl::StatusCode::kInvalidArgument, |
| HasSubstr(absl::StrCat("log_n, ", log_n, ", must be less than ", |
| rlwe::kMaxLogNumCoeffs, ".")))); |
| |
| log_n = (sizeof(typename TypeParam::Int) * 8) - 1; |
| if (log_n <= rlwe::kMaxLogNumCoeffs) { |
| EXPECT_THAT( |
| rlwe::InitializeNttParameters<TypeParam>(log_n, modulus_params.get()), |
| StatusIs( |
| ::absl::StatusCode::kInvalidArgument, |
| HasSubstr(absl::StrCat( |
| "log_n, ", log_n, |
| ", does not fit into underlying ModularInt::Int type.")))); |
| } |
| } |
| } |
| |
| TYPED_TEST(NttParametersTest, PrimitiveNthRootOfUnity) { |
| unsigned int log_ns[] = {2u, 4u, 6u, 8u, 11u}; |
| unsigned int len = 5; |
| |
| for (const auto& params : |
| rlwe::testing::ContextParameters<TypeParam>::Value()) { |
| // Do not create a context, since it creates NttParameters already. Instead, |
| // create the modulus parameters manually. |
| ASSERT_OK_AND_ASSIGN(auto modulus_params, |
| TypeParam::Params::Create(params.modulus)); |
| |
| for (unsigned int i = 0; i < len; i++) { |
| ASSERT_OK_AND_ASSIGN(TypeParam w, |
| rlwe::internal::PrimitiveNthRootOfUnity<TypeParam>( |
| log_ns[i], modulus_params.get())); |
| unsigned int n = 1 << log_ns[i]; |
| |
| // Ensure it is really a n-th root of unity. |
| auto res = w.ModExp(n, modulus_params.get()); |
| auto one = TypeParam::ImportOne(modulus_params.get()); |
| EXPECT_EQ(res, one) << "Not an n-th root of unity."; |
| |
| // Ensure it is really a primitive n-th root of unity. |
| auto res2 = w.ModExp(n / 2, modulus_params.get()); |
| EXPECT_NE(res2, one) << "Not a primitive n-th root of unity."; |
| } |
| } |
| } |
| |
| TYPED_TEST(NttParametersTest, NttPsis) { |
| for (const auto& params : |
| rlwe::testing::ContextParameters<TypeParam>::Value()) { |
| // Do not create a context, since it creates NttParameters already. Instead, |
| // create the modulus parameters manually. |
| ASSERT_OK_AND_ASSIGN(auto modulus_params, |
| TypeParam::Params::Create(params.modulus)); |
| const size_t n = 1 << params.log_n; |
| // The values of psi should be the powers of the primitive 2n-th root of |
| // unity. |
| // Obtain the psis. |
| ASSERT_OK_AND_ASSIGN( |
| std::vector<TypeParam> psis, |
| rlwe::internal::NttPsis<TypeParam>(params.log_n, modulus_params.get())); |
| |
| // Verify that that the 0th entry is 1. |
| TypeParam one = TypeParam::ImportOne(modulus_params.get()); |
| EXPECT_EQ(one, psis[0]); |
| |
| // Verify that the 1th entry is a primitive 2n-th root of unity. |
| auto r1 = psis[1].ModExp(2 * n, modulus_params.get()); |
| auto r2 = psis[1].ModExp(n, modulus_params.get()); |
| EXPECT_EQ(one, r1); |
| EXPECT_NE(one, r2); |
| |
| // Verify that each subsequent entry is the appropriate power of the 1th |
| // entry. |
| for (unsigned int i = 2; i < n; i++) { |
| auto ri = psis[1].ModExp(i, modulus_params.get()); |
| EXPECT_EQ(psis[i], ri); |
| } |
| } |
| } |
| |
| TYPED_TEST(NttParametersTest, NttPsisBitrev) { |
| for (const auto& params : |
| rlwe::testing::ContextParameters<TypeParam>::Value()) { |
| // Do not create a context, since it creates NttParameters already. Instead, |
| // create the modulus parameters manually. |
| ASSERT_OK_AND_ASSIGN(auto modulus_params, |
| TypeParam::Params::Create(params.modulus)); |
| const size_t n = 1 << params.log_n; |
| |
| // The values of psi should be bitreversed. |
| // Target vector: obtain the psis in bitreversed order. |
| ASSERT_OK_AND_ASSIGN( |
| std::vector<TypeParam> psis_bitrev, |
| rlwe::NttPsisBitrev<TypeParam>(params.log_n, modulus_params.get())); |
| // Obtain the psis. |
| ASSERT_OK_AND_ASSIGN( |
| std::vector<TypeParam> psis, |
| rlwe::internal::NttPsis<TypeParam>(params.log_n, modulus_params.get())); |
| // Obtain the mapping for bitreversed order |
| std::vector<unsigned int> bit_rev = |
| rlwe::internal::BitrevArray(params.log_n); |
| |
| for (unsigned int i = 0; i < n; i++) { |
| EXPECT_EQ(psis_bitrev[i], psis[bit_rev[i]]); |
| } |
| } |
| } |
| |
| TYPED_TEST(NttParametersTest, NttPsisInvBitrev) { |
| for (const auto& params : |
| rlwe::testing::ContextParameters<TypeParam>::Value()) { |
| // Do not create a context, since it creates NttParameters already. Instead, |
| // create the modulus parameters manually. |
| ASSERT_OK_AND_ASSIGN(auto modulus_params, |
| TypeParam::Params::Create(params.modulus)); |
| const size_t n = 1 << params.log_n; |
| |
| // The values of the vectors should be psi^(-(brv[k]+1) for all k. |
| // Target vector: obtain the psi inv in bit reversed order. |
| ASSERT_OK_AND_ASSIGN( |
| std::vector<TypeParam> psis_inv_bitrev, |
| rlwe::NttPsisInvBitrev<TypeParam>(params.log_n, modulus_params.get())); |
| // Obtain the psis. |
| ASSERT_OK_AND_ASSIGN( |
| std::vector<TypeParam> psis, |
| rlwe::internal::NttPsis<TypeParam>(params.log_n, modulus_params.get())); |
| // Obtain the mapping for bitreversed order |
| std::vector<unsigned int> bit_rev = |
| rlwe::internal::BitrevArray(params.log_n); |
| |
| for (unsigned int i = 0; i < n; i++) { |
| EXPECT_EQ(modulus_params->One(), |
| psis_inv_bitrev[i] |
| .Mul(psis[1], modulus_params.get()) |
| .Mul(psis[bit_rev[i]], modulus_params.get()) |
| .ExportInt(modulus_params.get())); |
| } |
| } |
| } |
| |
| TEST(NttParametersRegularTest, Bitrev) { |
| for (unsigned int log_N = 2; log_N < 11; log_N++) { |
| unsigned int N = 1 << log_N; |
| std::vector<unsigned int> bit_rev = rlwe::internal::BitrevArray(log_N); |
| |
| // Visit each entry of the array. |
| for (unsigned int i = 0; i < N; i++) { |
| for (unsigned int j = 0; j < log_N; j++) { |
| // Ensure bit j of i is equal to bit (log_N - j) of bit_rev[i]. |
| rlwe::Uint64 mask1 = 1 << j; |
| rlwe::Uint64 mask2 = 1 << (log_N - j - 1); |
| EXPECT_EQ((i & mask1) == 0, (bit_rev[i] & mask2) == 0); |
| } |
| } |
| } |
| } |
| |
| TYPED_TEST(NttParametersTest, IncorrectNTTParams) { |
| for (const auto& params : |
| rlwe::testing::ContextParameters<TypeParam>::Value()) { |
| // Do not create a context, since it creates NttParameters already. Instead, |
| // create the modulus parameters manually. |
| // modulus + 2, will no longer be 1 mod 2*n |
| ASSERT_OK_AND_ASSIGN(auto modulus_params, |
| TypeParam::Params::Create(params.modulus + 2)); |
| |
| EXPECT_THAT( |
| rlwe::InitializeNttParameters<TypeParam>(params.log_n, |
| modulus_params.get()), |
| StatusIs(::absl::StatusCode::kInvalidArgument, |
| HasSubstr(absl::StrCat("modulus is not 1 mod 2n for logn, ", |
| params.log_n)))); |
| } |
| } |
| |
| // Test all the NTT Parameter fields. |
| TYPED_TEST(NttParametersTest, Initialize) { |
| for (const auto& params : |
| rlwe::testing::ContextParameters<TypeParam>::Value()) { |
| // Do not create a context, since it creates NttParameters already. Instead, |
| // create the modulus parameters manually. |
| ASSERT_OK_AND_ASSIGN(auto modulus_params, |
| TypeParam::Params::Create(params.modulus)); |
| const size_t n = 1 << params.log_n; |
| |
| ASSERT_OK_AND_ASSIGN(rlwe::NttParameters<TypeParam> ntt_params, |
| rlwe::InitializeNttParameters<TypeParam>( |
| params.log_n, modulus_params.get())); |
| |
| TypeParam one = TypeParam::ImportOne(modulus_params.get()); |
| |
| // Obtain the mapping for bitreversed order |
| std::vector<unsigned int> bit_rev = |
| rlwe::internal::BitrevArray(params.log_n); |
| |
| // Test first entry of psis in bitreversed order is one. |
| EXPECT_EQ(one, ntt_params.psis_bitrev[0]); |
| |
| // Test n/2-th (brv[1]-th) entry of psis in bitreversed order is a primitive |
| // 2n-th root of unity. |
| auto psi = ntt_params.psis_bitrev[bit_rev[1]]; |
| auto r1 = psi.ModExp(2 * n, modulus_params.get()); |
| auto r2 = psi.ModExp(n, modulus_params.get()); |
| EXPECT_EQ(one, r1); |
| EXPECT_NE(one, r2); |
| |
| // The values of psis should be the powers of the primitive 2n-th root of |
| // unity in bitreversed order. |
| for (unsigned int i = 0; i < n; i++) { |
| auto bi = psi.ModExp(i, modulus_params.get()); |
| EXPECT_EQ(ntt_params.psis_bitrev[bit_rev[i]], bi); |
| } |
| |
| // Test psis_inv_bitrev contains the inverses of the powers of psi in |
| // bitreversed order, each multiplied by the inverse of psi. |
| for (unsigned int i = 0; i < n; i++) { |
| EXPECT_EQ(modulus_params->One(), |
| ntt_params.psis_bitrev[i] |
| .Mul(psi, modulus_params.get()) |
| .Mul(ntt_params.psis_inv_bitrev[i], modulus_params.get()) |
| .ExportInt(modulus_params.get())); |
| } |
| } |
| } |
| |
| } // namespace |