| /* |
| * Copyright 2018 Google LLC. |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * https://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "sample_error.h" |
| |
| #include <cstdint> |
| #include <vector> |
| |
| #include <gmock/gmock.h> |
| #include <gtest/gtest.h> |
| #include "context.h" |
| #include "montgomery.h" |
| #include "symmetric_encryption.h" |
| #include "testing/parameters.h" |
| #include "testing/status_matchers.h" |
| #include "testing/status_testing.h" |
| #include "testing/testing_prng.h" |
| |
| namespace { |
| |
| using ::rlwe::testing::StatusIs; |
| using ::testing::HasSubstr; |
| |
| const int kTestingRounds = 10; |
| const std::vector<rlwe::Uint64> variances = {8, 15, 29, 50}; |
| |
| template <typename ModularInt> |
| class SampleErrorTest : public ::testing::Test {}; |
| TYPED_TEST_SUITE(SampleErrorTest, rlwe::testing::ModularIntTypes); |
| |
| TYPED_TEST(SampleErrorTest, CheckUpperBoundOnNoise) { |
| using Int = typename TypeParam::Int; |
| |
| auto prng = absl::make_unique<rlwe::testing::TestingPrng>(0); |
| |
| for (const auto& params : |
| rlwe::testing::ContextParameters<TypeParam>::Value()) { |
| ASSERT_OK_AND_ASSIGN(auto context, |
| rlwe::RlweContext<TypeParam>::Create(params)); |
| |
| for (auto variance : variances) { |
| for (int i = 0; i < kTestingRounds; i++) { |
| ASSERT_OK_AND_ASSIGN(std::vector<TypeParam> error, |
| rlwe::SampleFromErrorDistribution<TypeParam>( |
| context->GetN(), variance, prng.get(), |
| context->GetModulusParams())); |
| // Check that each coefficient is in [-2*variance, 2*variance] |
| for (int j = 0; j < context->GetN(); j++) { |
| Int reduced = error[j].ExportInt(context->GetModulusParams()); |
| if (reduced > (context->GetModulus() >> 1)) { |
| EXPECT_LT(context->GetModulus() - reduced, 2 * variance + 1); |
| } else { |
| EXPECT_LT(reduced, 2 * variance + 1); |
| } |
| } |
| } |
| } |
| } |
| } |
| |
| TYPED_TEST(SampleErrorTest, FailOnTooLargeVariance) { |
| auto prng = absl::make_unique<rlwe::testing::TestingPrng>(0); |
| for (const auto& params : |
| rlwe::testing::ContextParameters<TypeParam>::Value()) { |
| ASSERT_OK_AND_ASSIGN(auto context, |
| rlwe::RlweContext<TypeParam>::Create(params)); |
| |
| rlwe::Uint64 variance = rlwe::kMaxVariance + 1; |
| EXPECT_THAT( |
| rlwe::SampleFromErrorDistribution<TypeParam>( |
| context->GetN(), variance, prng.get(), context->GetModulusParams()), |
| StatusIs( |
| absl::StatusCode::kInvalidArgument, |
| HasSubstr(absl::StrCat("The variance, ", variance, |
| ", must be at most ", rlwe::kMaxVariance)))); |
| } |
| } |
| |
| } // namespace |