| /* |
| * Copyright 2017 Google LLC. |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * https://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #ifndef RLWE_SYMMETRIC_ENCRYPTION_H_ |
| #define RLWE_SYMMETRIC_ENCRYPTION_H_ |
| |
| #include <algorithm> |
| #include <cstdint> |
| #include <vector> |
| |
| #include "error_params.h" |
| #include "polynomial.h" |
| #include "prng/integral_prng_types.h" |
| #include "prng/prng.h" |
| #include "sample_error.h" |
| #include "serialization.pb.h" |
| #include "status_macros.h" |
| |
| namespace rlwe { |
| |
| // This file implements the somewhat homomorphic symmetric-key encryption scheme |
| // from "Fully Homomorphic Encryption from Ring-LWE and Security for Key |
| // Dependent Messages" by Zvika Brakerski and Vinod Vaikuntanathan. This |
| // encryption scheme uses Ring Learning with Errors (RLWE). |
| // http://www.wisdom.weizmann.ac.il/~zvikab/localpapers/IdealHom.pdf |
| // |
| // The scheme has CPA security under the hardness of the |
| // Ring-Learning with Errors problem (see reference above for details). We do |
| // not implement protections against timing attacks. |
| // |
| // The encryption scheme in this file is not fully homomorphic. It does not |
| // implement any sort of bootstrapping. |
| |
| // Represents a ciphertext encrypted using a symmetric-key version of the ring |
| // learning-with-errors (RLWE) encryption scheme. See the comments that follow |
| // throughout this file for full details on the particular encryption scheme. |
| // |
| // This implementation supports the following homomorphic operations: |
| // - Homomorphic addition. |
| // - Scalar multiplication by a polynomial (absorption) |
| // - Homomorphic multiplication. |
| // |
| // This implementation is only "somewhat homomorphic," not fully homomorphic. |
| // There is no bootstrapping, so a limited number of homomorphic operations can |
| // be performed before so much error accumulates that decryption is impossible. |
| // |
| // Each ciphertext comprises a vector of polynomials <c0, ..., cN>. Initially, |
| // a ciphertext comprises a pair <c0, c1>. Homomorphic multiplications cause |
| // the vector to grow longer. |
| template <typename ModularInt> |
| class SymmetricRlweCiphertext { |
| using Int = typename ModularInt::Int; |
| // BigInt is required in order to multiply two Int and ensure that no overflow |
| // occurs during the multiplication of two ciphertexts. |
| using BigInt = typename ModularInt::BigInt; |
| |
| public: |
| // Default and copy constructors. |
| explicit SymmetricRlweCiphertext(const typename ModularInt::Params* params, |
| const ErrorParams<ModularInt>* error_params) |
| : modulus_params_(params), |
| error_params_(error_params), |
| power_of_s_(1), |
| error_(0) {} |
| SymmetricRlweCiphertext(const SymmetricRlweCiphertext& that) = default; |
| |
| // Create a ciphertext by supplying the vector of components. |
| explicit SymmetricRlweCiphertext(std::vector<Polynomial<ModularInt>> c, |
| int power_of_s, double error, |
| const typename ModularInt::Params* params, |
| const ErrorParams<ModularInt>* error_params) |
| : c_(std::move(c)), |
| modulus_params_(params), |
| error_params_(error_params), |
| power_of_s_(power_of_s), |
| error_(error) {} |
| |
| // Homomorphic addition: add the polynomials representing the ciphertexts |
| // component-wise. The example below demonstrates why this procedure works |
| // properly in the two-component case. The quantities a, s, m, t, and e are |
| // introduced during encryption and are explained in the SymmetricRlweKey |
| // class. |
| // |
| // (a1 * s + m1 + t * e1, -a1) |
| // + (a2 * s + m2 + t * e2, -a2) |
| // ------------------------------ |
| // ((a1 + a2) * s + (m1 + m2) + t * (e1 + e2), -(a1 + a2)) |
| // |
| // Substitute (a1 + a2) = a3, (e1 + e2) = e3: |
| // |
| // (a3 * s + (m1 + m2) + t * e3, -a3) |
| // |
| // This result is a valid ciphertext where the value of a has changed, the |
| // error has increased, and the encoded plaintext contains the sum of the |
| // plaintexts that were encoded in the original two ciphertexts. |
| rlwe::StatusOr<SymmetricRlweCiphertext> operator+( |
| const SymmetricRlweCiphertext& that) const { |
| SymmetricRlweCiphertext out = *this; |
| RLWE_RETURN_IF_ERROR(out.AddInPlace(that)); |
| return out; |
| } |
| |
| absl::Status AddInPlace(const SymmetricRlweCiphertext& that) { |
| if (power_of_s_ != that.power_of_s_) { |
| return absl::InvalidArgumentError( |
| "Ciphertexts must be encrypted with the same key power."); |
| } |
| |
| if (c_.size() < that.c_.size()) { |
| Polynomial<ModularInt> zero(that.c_[0].Len(), modulus_params_); |
| c_.resize(that.c_.size(), zero); |
| } |
| |
| for (int i = 0; i < that.c_.size(); i++) { |
| RLWE_RETURN_IF_ERROR(c_[i].AddInPlace(that.c_[i], modulus_params_)); |
| } |
| |
| error_ += that.error_; |
| return absl::OkStatus(); |
| } |
| |
| // Homomorphic subtraction: subtract the polynomials representing the |
| // ciphertexts component-wise. The example below demonstrates why this |
| // procedure works properly in the two-component case. The quantities a, s, m, |
| // t, and e are introduced during encryption and are explained in the |
| // SymmetricRlweKey class. |
| // |
| // (a1 * s + m1 + t * e1, -a1) |
| // - (a2 * s + m2 + t * e2, -a2) |
| // ------------------------------ |
| // ((a1 - a2) * s + (m1 - m2) + t * (e1 - e2), -(a1 - a2)) |
| // |
| // Substitute (a1 - a2) = a3, (e1 - e2) = e3: |
| // |
| // (a3 * s + (m1 - m2) + t * e3, -a3) |
| // |
| // This result is a valid ciphertext where the value of a has changed, the |
| // error has increased, and the encoded plaintext contains the sum of the |
| // plaintexts that were encoded in the original two ciphertexts. |
| rlwe::StatusOr<SymmetricRlweCiphertext> operator-( |
| const SymmetricRlweCiphertext& that) const { |
| SymmetricRlweCiphertext out = *this; |
| RLWE_RETURN_IF_ERROR(out.SubInPlace(that)); |
| return out; |
| } |
| |
| absl::Status SubInPlace(const SymmetricRlweCiphertext& that) { |
| if (power_of_s_ != that.power_of_s_) { |
| return absl::InvalidArgumentError( |
| "Ciphertexts must be encrypted with the same key power."); |
| } |
| |
| if (c_.size() < that.c_.size()) { |
| Polynomial<ModularInt> zero(that.c_[0].Len(), modulus_params_); |
| c_.resize(that.c_.size(), zero); |
| } |
| |
| for (int i = 0; i < that.c_.size(); i++) { |
| RLWE_RETURN_IF_ERROR(c_[i].SubInPlace(that.c_[i], modulus_params_)); |
| } |
| |
| error_ += that.error_; |
| return absl::OkStatus(); |
| } |
| |
| // Homomorphic absorbtion. Multiplies the current ciphertext {m1}_s (plaintext |
| // m1 encrypted with symmetric key s) by a plaintext m2, resulting in a |
| // ciphertext {m1 * m2}_s that stores m1 * m2 encrypted with symmetric key s. |
| // |
| // DO NOT CONFUSE THIS OPERATION WITH HOMOMORPHIC MULTIPLICATION. |
| // |
| // To perform this operation, multiply the each component of the |
| // ciphertext by the plaintext polynomial. The example below demonstrates why |
| // this procedure works properly in the two-component case. The quantities a, |
| // s, m, t, and e are introduced during encryption and are explained in the |
| // Encrypt() function later in this file. |
| // |
| // (a1 * s + m1 + t * e1, -a1) * p |
| // = (a1 * s * p + m1 * p + t * e1 * p) |
| // |
| // Substitute (a1 * p) = a2 and (e1 * p) = e2: |
| // |
| // (a2 * s + m1 * p + t * e2) |
| // |
| // This result is a valid ciphertext where the value of a has changed, the |
| // error has increased, and the encoded plaintext contains the product of |
| // m1 and p. |
| // |
| // A few more details about the multiplication that takes place: |
| // |
| // The value stored in the resulting ciphertext is (m1 * m2) (mod 2^N + 1) |
| // (mod t), where N is the number of coefficients in s (or m1 or m2, since |
| // the all have the same number of coefficients). In other words, the |
| // result is the remainder of (m1 * m2) mod the polynomial (2^N + 1) with |
| // each of the coefficients the ntaken mod t. Any coefficient between 0 and |
| // modulus / 2 is treated as a positive number for the purposes of the final |
| // (mod t); any coefficient between modulus/2 and modulus is treated as |
| // a negative number for the purposes of the final (mod t). |
| rlwe::StatusOr<SymmetricRlweCiphertext> operator*( |
| const Polynomial<ModularInt>& that) const { |
| SymmetricRlweCiphertext out = *this; |
| RLWE_RETURN_IF_ERROR(out.AbsorbInPlace(that)); |
| return out; |
| } |
| |
| absl::Status AbsorbInPlace(const Polynomial<ModularInt>& that) { |
| for (auto& component : this->c_) { |
| RLWE_RETURN_IF_ERROR(component.MulInPlace(that, modulus_params_)); |
| } |
| error_ *= error_params_->B_plaintext(); |
| return absl::OkStatus(); |
| } |
| |
| // Homomorphically absorb a plaintext scalar. This function is exactly like |
| // homomorphic absorb above, except the plaintext is a constant. |
| rlwe::StatusOr<SymmetricRlweCiphertext> operator*( |
| const ModularInt& that) const { |
| SymmetricRlweCiphertext out = *this; |
| RLWE_RETURN_IF_ERROR(out.AbsorbInPlace(that)); |
| return out; |
| } |
| |
| absl::Status AbsorbInPlace(const ModularInt& that) { |
| for (auto& component : this->c_) { |
| RLWE_RETURN_IF_ERROR(component.MulInPlace(that, modulus_params_)); |
| } |
| error_ *= static_cast<double>(that.ExportInt(modulus_params_)); |
| return absl::OkStatus(); |
| } |
| |
| // Homomorphic multiply. Given two ciphertexts {m1}_s, {m2}_s containing |
| // messages m1 and m2 encrypted with the same secret key s, return the |
| // ciphertext {m1 * m2}_s containing the product of the messages. |
| // |
| // To perform this operation, treat the two ciphertext vectors as polynomials |
| // and perform a polynomial multiplication: |
| // |
| // <c0, c1> * <c0', c1'> = <c0 * c0, c0 * c1 + c1 * c0, c1 * c1> |
| // |
| // If the two ciphertext vectors are of length m and n, the resulting |
| // ciphertext is of length m + n - 1. |
| // |
| // The details of the multiplication that takes place between m1 and m2 are |
| // the same as in the homomorphic absorb operation above (the other overload |
| // of the * operator). |
| rlwe::StatusOr<SymmetricRlweCiphertext> operator*( |
| const SymmetricRlweCiphertext& that) { |
| if (power_of_s_ != that.power_of_s_) { |
| return absl::InvalidArgumentError( |
| "Ciphertexts must be encrypted with the same key power."); |
| } |
| if (c_.size() <= 0 || that.c_.size() <= 0) { |
| return absl::InvalidArgumentError( |
| "Cannot multiply using an empty ciphertext."); |
| } |
| if (c_[0].Len() <= 0 || that.c_[0].Len() <= 0) { |
| return absl::InvalidArgumentError( |
| "Cannot multiply using an empty polynomial in the ciphertext."); |
| } |
| Polynomial<ModularInt> temp(c_[0].Len(), modulus_params_); |
| std::vector<Polynomial<ModularInt>> result(c_.size() + that.c_.size() - 1, |
| temp); |
| for (int i = 0; i < c_.size(); i++) { |
| for (int j = 0; j < that.c_.size(); j++) { |
| RLWE_ASSIGN_OR_RETURN(temp, c_[i].Mul(that.c_[j], modulus_params_)); |
| RLWE_RETURN_IF_ERROR(result[i + j].AddInPlace(temp, modulus_params_)); |
| } |
| } |
| |
| return SymmetricRlweCiphertext(std::move(result), power_of_s_, |
| error_ * that.error_, modulus_params_, |
| error_params_); |
| } |
| |
| // Convert this ciphertext from (mod p) to (mod q). |
| // Assumes that ModularInt::Int and ModularIntQ::Int are the same type. |
| // |
| // The current modulus (mod t) must be equal to modulus q (mod t). |
| // This will always be true. For NTT to work properly, any modulus must be |
| // of the form 2N + 1, where N is a power of 2. Likewise, the implementation |
| // requires that t is a power of 2. This means that, for any modulus q and |
| // modulus t allowed by the RLWE implementation, q % t == 1. |
| template <typename ModularIntQ> |
| rlwe::StatusOr<SymmetricRlweCiphertext<ModularIntQ>> SwitchModulus( |
| const NttParameters<ModularInt>* ntt_params_p, |
| const typename ModularIntQ::Params* modulus_params_q, |
| const NttParameters<ModularIntQ>* ntt_params_q, |
| const ErrorParams<ModularIntQ>* error_params_q, const Int& t) { |
| Int p = modulus_params_->modulus; |
| Int q = modulus_params_q->modulus; |
| |
| // Configuration error. |
| if (p % t != q % t) { |
| return absl::InvalidArgumentError("p % t != q % t"); |
| } |
| |
| SymmetricRlweCiphertext<ModularIntQ> output(modulus_params_q, |
| error_params_q); |
| output.power_of_s_ = power_of_s_; |
| // Overestimate the ratio of the two moduli. |
| double modulus_ratio = static_cast<double>(modulus_params_q->log_modulus) / |
| modulus_params_->log_modulus; |
| output.error_ = modulus_ratio * error_ + error_params_q->B_scale(); |
| |
| output.c_.reserve(c_.size()); |
| for (const Polynomial<ModularInt>& c : c_) { |
| // Extract each component of the ciphertext from NTT form. |
| std::vector<ModularInt> coeffs_p = |
| c.InverseNtt(ntt_params_p, modulus_params_); |
| std::vector<ModularIntQ> coeffs_q; |
| coeffs_q.reserve(coeffs_p.size()); |
| |
| // Convert each coefficient of the polynomial from (mod p) to (mod q) |
| for (const ModularInt& coeff_p : coeffs_p) { |
| Int int_p = coeff_p.ExportInt(modulus_params_); |
| |
| // Scale the integer. |
| Int int_q = static_cast<Int>(ModularInt::DivAndTruncate( |
| static_cast<BigInt>(int_p) * static_cast<BigInt>(q), |
| static_cast<BigInt>(p))); |
| |
| // Ensure that int_p = int_q mod t by changing int_q as little as |
| // possible. |
| Int int_p_mod_t = int_p % t; |
| Int int_q_mod_t = int_q % t; |
| Int adjustment_up = modulus_params_->Zero(); |
| Int adjustment_down = modulus_params_->Zero(); |
| |
| // Determine whether to adjust int_q up or down to make sure int_q = |
| // int_p (mod t). |
| adjustment_up = int_p_mod_t - int_q_mod_t; |
| adjustment_down = t + int_q_mod_t - int_p_mod_t; |
| if (int_p_mod_t < int_q_mod_t) { |
| adjustment_up = adjustment_up + t; |
| adjustment_down = adjustment_down - t; |
| } |
| |
| RLWE_ASSIGN_OR_RETURN(auto m_int_q, |
| ModularIntQ::ImportInt(int_q, modulus_params_q)); |
| if (adjustment_up > adjustment_down) { |
| RLWE_ASSIGN_OR_RETURN( |
| auto m_adjustment_up, |
| ModularIntQ::ImportInt(adjustment_up, modulus_params_q)); |
| // Adjust up. |
| coeffs_q.push_back( |
| std::move(m_adjustment_up.AddInPlace(m_int_q, modulus_params_q))); |
| } else { |
| RLWE_ASSIGN_OR_RETURN( |
| auto m_adjustment_down, |
| ModularIntQ::ImportInt(q - adjustment_down, modulus_params_q)); |
| // Adjust down. |
| coeffs_q.push_back(std::move( |
| m_adjustment_down.AddInPlace(m_int_q, modulus_params_q))); |
| } |
| } |
| |
| // Convert back to NTT. |
| output.c_.push_back(Polynomial<ModularIntQ>::ConvertToNtt( |
| std::move(coeffs_q), ntt_params_q, modulus_params_q)); |
| } |
| |
| return output; |
| } |
| |
| // Given a ciphertext c encrypting a plaintext p(x) under secret key s(x), |
| // returns a ciphertext c' encrypting p(x^power) under the secret key |
| // s(x^power). |
| // Power must be an odd non-negative integer less than 2 * num_coeffs. |
| // This method uses NTT conversions to apply the substitution in the |
| // coefficient domain, and should be avoided if performance is an issue. |
| // Substitutions of the form 2^j + 1 are used to obliviously expand a query |
| // ciphertext into a query vector. |
| rlwe::StatusOr<SymmetricRlweCiphertext> Substitute( |
| int substitution_power, |
| const NttParameters<ModularInt>* ntt_params) const { |
| SymmetricRlweCiphertext output(modulus_params_, error_params_); |
| output.c_.reserve(c_.size()); |
| |
| for (const Polynomial<ModularInt>& c : c_) { |
| RLWE_ASSIGN_OR_RETURN( |
| auto elt, |
| c.Substitute(substitution_power, ntt_params, modulus_params_)); |
| output.c_.push_back(std::move(elt)); |
| } |
| output.power_of_s_ = (power_of_s_ * substitution_power) % (2 * c_[0].Len()); |
| output.error_ = error_; |
| return output; |
| } |
| |
| rlwe::StatusOr<SerializedSymmetricRlweCiphertext> Serialize() const { |
| SerializedSymmetricRlweCiphertext output; |
| output.set_power_of_s(power_of_s_); |
| output.set_error(error_); |
| |
| for (const Polynomial<ModularInt>& c : c_) { |
| RLWE_ASSIGN_OR_RETURN(*output.add_c(), c.Serialize(modulus_params_)); |
| } |
| |
| return output; |
| } |
| |
| static rlwe::StatusOr<SymmetricRlweCiphertext> Deserialize( |
| const SerializedSymmetricRlweCiphertext& serialized, |
| const typename ModularInt::Params* modulus_params, |
| const ErrorParams<ModularInt>* error_params) { |
| SymmetricRlweCiphertext output(modulus_params, error_params); |
| output.power_of_s_ = serialized.power_of_s(); |
| output.error_ = serialized.error(); |
| |
| if (serialized.c_size() <= 0) { |
| return absl::InvalidArgumentError("Ciphertext cannot be empty."); |
| } else if (serialized.c_size() > kMaxNumCoeffs) { |
| return absl::InvalidArgumentError( |
| absl::StrCat("Number of coefficients, ", serialized.c_size(), |
| ", cannot be more than ", kMaxNumCoeffs, ".")); |
| } |
| |
| for (int i = 0; i < serialized.c_size(); i++) { |
| RLWE_ASSIGN_OR_RETURN(auto elt, Polynomial<ModularInt>::Deserialize( |
| serialized.c(i), modulus_params)); |
| output.c_.push_back(std::move(elt)); |
| } |
| |
| return output; |
| } |
| |
| // Accessors. |
| unsigned int Len() const { return c_.size(); } |
| |
| rlwe::StatusOr<Polynomial<ModularInt>> Component(int index) const { |
| if (0 > index || index >= c_.size()) { |
| return absl::InvalidArgumentError("Index out of range."); |
| } |
| return c_[index]; |
| } |
| |
| const typename ModularInt::Params* ModulusParams() const { |
| return modulus_params_; |
| } |
| const rlwe::ErrorParams<ModularInt>* ErrorParams() const { |
| return error_params_; |
| } |
| int PowerOfS() const { return power_of_s_; } |
| double Error() const { return error_; } |
| void SetError(double error) { error_ = error; } |
| |
| private: |
| // The ciphertext. |
| std::vector<Polynomial<ModularInt>> c_; |
| |
| // ModularInt parameters. |
| const typename ModularInt::Params* modulus_params_; |
| |
| // Error parameters. |
| const rlwe::ErrorParams<ModularInt>* error_params_; |
| |
| // The power a in s(x^a) that the ciphertext can be decrypted with. |
| int power_of_s_; |
| |
| // A heuristic on the error of the ciphertext. |
| double error_; |
| |
| // Make this class a friend of any version of this class, no matter the |
| // template. |
| template <typename Q> |
| friend class SymmetricRlweCiphertext; |
| }; |
| |
| // Holds a key that can be used to encrypt messages using the RLWE-based |
| // encryption scheme. |
| template <typename ModularInt> |
| class SymmetricRlweKey { |
| using Int = typename ModularInt::Int; |
| |
| public: |
| // Allow copy, copy-assign, move and move-assign. |
| SymmetricRlweKey(const SymmetricRlweKey&) = default; |
| SymmetricRlweKey& operator=(const SymmetricRlweKey&) = default; |
| SymmetricRlweKey(SymmetricRlweKey&&) = default; |
| SymmetricRlweKey& operator=(SymmetricRlweKey&&) = default; |
| ~SymmetricRlweKey() = default; |
| |
| // Static factory that samples a key from the error distribution. The |
| // polynomial representing the key must have a number of coefficients that is |
| // a power of two, which is enforced by the first argument. |
| // |
| // Does not take ownership of rand, modulus_params or ntt_params. |
| static rlwe::StatusOr<SymmetricRlweKey> Sample( |
| unsigned int log_num_coeffs, uint64_t variance, uint64_t log_t, |
| const typename ModularInt::Params* modulus_params, |
| const NttParameters<ModularInt>* ntt_params, SecurePrng* prng) { |
| RLWE_ASSIGN_OR_RETURN( |
| auto error, SampleFromErrorDistribution<ModularInt>( |
| 1 << log_num_coeffs, variance, prng, modulus_params)); |
| Polynomial<ModularInt> key = Polynomial<ModularInt>::ConvertToNtt( |
| std::move(error), ntt_params, modulus_params); |
| RLWE_ASSIGN_OR_RETURN( |
| auto t_mod, ModularInt::ImportInt((modulus_params->One() << log_t) + |
| modulus_params->One(), |
| modulus_params)); |
| return SymmetricRlweKey(std::move(key), variance, log_t, std::move(t_mod), |
| modulus_params, modulus_params, ntt_params); |
| } |
| |
| rlwe::StatusOr<SerializedNttPolynomial> Serialize() const { |
| return key_.Serialize(modulus_params_); |
| } |
| |
| // Deserialize using modulus params as also the plaintext modulus params. Use |
| // this when deserializing a non-modulus switched key. |
| static rlwe::StatusOr<SymmetricRlweKey> Deserialize( |
| Uint64 variance, Uint64 log_t, |
| const SerializedNttPolynomial& serialized_key, |
| const typename ModularInt::Params* modulus_params, |
| const NttParameters<ModularInt>* ntt_params) { |
| return Deserialize(variance, log_t, serialized_key, modulus_params, |
| modulus_params, ntt_params); |
| } |
| |
| static rlwe::StatusOr<SymmetricRlweKey> Deserialize( |
| Uint64 variance, Uint64 log_t, |
| const SerializedNttPolynomial& serialized_key, |
| const typename ModularInt::Params* modulus_params, |
| const typename ModularInt::Params* plaintext_modulus_params, |
| const NttParameters<ModularInt>* ntt_params) { |
| // Check that log_t is no larger than the log_modulus - 1. |
| if (log_t > modulus_params->log_modulus - 1) { |
| return absl::InvalidArgumentError(absl::StrCat( |
| "The value of log_t, ", log_t, ", must be smaller than ", |
| "log_modulus - 1, ", modulus_params->log_modulus - 1, ".")); |
| } |
| RLWE_ASSIGN_OR_RETURN( |
| Polynomial<ModularInt> key, |
| Polynomial<ModularInt>::Deserialize(serialized_key, modulus_params)); |
| RLWE_ASSIGN_OR_RETURN( |
| auto t_mod, |
| ModularInt::ImportInt((plaintext_modulus_params->One() << log_t) + |
| plaintext_modulus_params->One(), |
| plaintext_modulus_params)); |
| return SymmetricRlweKey(std::move(key), variance, log_t, std::move(t_mod), |
| modulus_params, plaintext_modulus_params, |
| ntt_params); |
| } |
| |
| // Generate a copy of this key in modulus q. |
| // |
| // The current modulus (mod t) must be equal to modulus q (mod t). This |
| // property is implicitly enforced by the design of the code as described |
| // by the corresponding comment on SymmetricRlweKey::SwitchModulus. This |
| // property is also dynamically enforced. |
| // |
| // The algorithms for modulus-switching ciphertexts and keys are similar but |
| // slightly different. In particular, RLWE keys are guaranteed to have small |
| // coefficients, and thus modulus switching can be made very simple. Hence |
| // we have 2 separate implementations of SwitchModulus for keys and |
| // ciphertexts. |
| template <typename ModularIntQ> |
| rlwe::StatusOr<SymmetricRlweKey<ModularIntQ>> SwitchModulus( |
| const typename ModularIntQ::Params* modulus_params_q, |
| const NttParameters<ModularIntQ>* ntt_params_q) const { |
| // Configuration failure. |
| Int t = (modulus_params_q->One() << log_t_) + modulus_params_q->One(); |
| if (modulus_params_->modulus % t != modulus_params_q->modulus % t) { |
| return absl::InvalidArgumentError("p % t != q % t"); |
| } |
| |
| typename ModularIntQ::Int p_mod_q = |
| modulus_params_->modulus % modulus_params_q->modulus; |
| std::vector<ModularInt> coeffs_p = |
| key_.InverseNtt(ntt_params_, modulus_params_); |
| std::vector<ModularIntQ> coeffs_q; |
| |
| // Convert each coefficient of the polynomial from (mod p) to (mod q) |
| for (const ModularInt& coeff_p : coeffs_p) { |
| // Ensure that negative numbers (mod p) are translated into negative |
| // numbers (mod q). |
| Int int_p = coeff_p.ExportInt(modulus_params_); |
| if (int_p > modulus_params_->modulus >> 1) { |
| int_p = int_p - p_mod_q; |
| } |
| |
| RLWE_ASSIGN_OR_RETURN(auto m_int_p, |
| ModularIntQ::ImportInt(int_p, modulus_params_q)); |
| coeffs_q.push_back(std::move(m_int_p)); |
| } |
| |
| // Convert back to NTT. |
| auto key_q = Polynomial<ModularIntQ>::ConvertToNtt( |
| std::move(coeffs_q), ntt_params_q, modulus_params_q); |
| |
| RLWE_ASSIGN_OR_RETURN( |
| auto t_mod, ModularInt::ImportInt((modulus_params_q->One() << log_t_) + |
| modulus_params_q->One(), |
| modulus_params_q)); |
| return SymmetricRlweKey<ModularIntQ>(std::move(key_q), variance_, log_t_, |
| std::move(t_mod), modulus_params_q, |
| modulus_params_q, ntt_params_q); |
| } |
| |
| // Given s(x), returns a secret key s(x^a). |
| // This performs an Inverse NTT on the key, substitutes the key in polynomial |
| // representation, and then performs an NTT again. |
| rlwe::StatusOr<SymmetricRlweKey> Substitute(const int power) const { |
| RLWE_ASSIGN_OR_RETURN( |
| auto t_mod, ModularInt::ImportInt((modulus_params_->One() << log_t_) + |
| modulus_params_->One(), |
| modulus_params_)); |
| RLWE_ASSIGN_OR_RETURN(auto sub, |
| key_.Substitute(power, ntt_params_, modulus_params_)); |
| return SymmetricRlweKey(std::move(sub), variance_, log_t_, std::move(t_mod), |
| modulus_params_, plaintext_modulus_params_, |
| ntt_params_); |
| } |
| |
| // Accessors. |
| unsigned int Len() const { return key_.Len(); } |
| const NttParameters<ModularInt>* NttParams() const { return ntt_params_; } |
| const typename ModularInt::Params* ModulusParams() const { |
| return modulus_params_; |
| } |
| const unsigned int BitsPerCoeff() const { return log_t_; } |
| const Uint64 Variance() const { return variance_; } |
| const unsigned int LogT() const { return log_t_; } |
| const ModularInt& PlaintextModulus() const { return t_mod_; } |
| const typename ModularInt::Params* PlaintextModulusParams() const { |
| return plaintext_modulus_params_; |
| } |
| const Polynomial<ModularInt>& Key() const { return key_; } |
| |
| // Add two homomorphic encryption keys. |
| rlwe::StatusOr<SymmetricRlweKey<ModularInt>> Add( |
| const SymmetricRlweKey<ModularInt>& other_key) { |
| if (variance_ != other_key.variance_) { |
| return absl::InvalidArgumentError(absl::StrCat( |
| "The variance of the other key, ", other_key.variance_, |
| ", is different than the variance of this key, ", variance_, ".")); |
| } |
| if (log_t_ != other_key.log_t_) { |
| return absl::InvalidArgumentError(absl::StrCat( |
| "The log_t of the other key, ", other_key.log_t_, |
| ", is different than the log_t of this key, ", log_t_, ".")); |
| } |
| if (t_mod_ != other_key.t_mod_) { |
| return absl::InvalidArgumentError( |
| absl::StrCat("The plaintext space of the other key is different than " |
| "the plaintext space of this key.")); |
| } |
| RLWE_ASSIGN_OR_RETURN(auto key, key_.Add(other_key.key_, modulus_params_)); |
| return SymmetricRlweKey<ModularInt>(std::move(key), variance_, log_t_, |
| t_mod_, modulus_params_, |
| plaintext_modulus_params_, ntt_params_); |
| } |
| |
| // Substract two homomorphic encryption keys. |
| rlwe::StatusOr<SymmetricRlweKey<ModularInt>> Sub( |
| const SymmetricRlweKey<ModularInt>& other_key) { |
| if (variance_ != other_key.variance_) { |
| return absl::InvalidArgumentError(absl::StrCat( |
| "The variance of the other key, ", other_key.variance_, |
| ", is different than the variance of this key, ", variance_, ".")); |
| } |
| if (log_t_ != other_key.log_t_) { |
| return absl::InvalidArgumentError(absl::StrCat( |
| "The log_t of the other key, ", other_key.log_t_, |
| ", is different than the log_t of this key, ", log_t_, ".")); |
| } |
| if (t_mod_ != other_key.t_mod_) { |
| return absl::InvalidArgumentError( |
| absl::StrCat("The plaintext space of the other key is different than " |
| "the plaintext space of this key.")); |
| } |
| RLWE_ASSIGN_OR_RETURN(auto key, key_.Sub(other_key.key_, modulus_params_)); |
| return SymmetricRlweKey<ModularInt>(std::move(key), variance_, log_t_, |
| t_mod_, modulus_params_, |
| plaintext_modulus_params_, ntt_params_); |
| } |
| |
| // Static function to create a null key (with value 0). |
| static rlwe::StatusOr<SymmetricRlweKey> NullKey( |
| unsigned int log_num_coeffs, Uint64 variance, Uint64 log_t, |
| const typename ModularInt::Params* modulus_params, |
| const NttParameters<ModularInt>* ntt_params) { |
| Polynomial<ModularInt> zero(1 << log_num_coeffs, modulus_params); |
| RLWE_ASSIGN_OR_RETURN( |
| auto t_mod, ModularInt::ImportInt((modulus_params->One() << log_t) + |
| modulus_params->One(), |
| modulus_params)); |
| return SymmetricRlweKey(std::move(zero), variance, log_t, std::move(t_mod), |
| modulus_params, modulus_params, ntt_params); |
| } |
| |
| private: |
| // The contents of the key itself. |
| Polynomial<ModularInt> key_; |
| |
| // The variance of the binomial distribution from which the key and error are |
| // drawn. |
| Uint64 variance_; |
| |
| // The maximum size of any one coefficient of the polynomial representing a |
| // plaintext message. |
| unsigned int log_t_; |
| ModularInt t_mod_; |
| |
| // NTT parameters. |
| const NttParameters<ModularInt>* ntt_params_; |
| |
| // ModularInt parameters. |
| const typename ModularInt::Params* modulus_params_; |
| const typename ModularInt::Params* plaintext_modulus_params_; |
| |
| // A constructor. Does not take ownership of params. |
| SymmetricRlweKey(Polynomial<ModularInt> key, Uint64 variance, |
| unsigned int log_t, ModularInt t_mod, |
| const typename ModularInt::Params* modulus_params, |
| const typename ModularInt::Params* plaintext_modulus_params, |
| const NttParameters<ModularInt>* ntt_params) |
| : key_(std::move(key)), |
| variance_(variance), |
| log_t_(log_t), |
| t_mod_(std::move(t_mod)), |
| ntt_params_(ntt_params), |
| modulus_params_(modulus_params), |
| plaintext_modulus_params_(plaintext_modulus_params) {} |
| |
| // Make this class a friend of any version of this class, no matter the |
| // template. |
| template <typename Q> |
| friend class SymmetricRlweKey; |
| }; |
| |
| // Encrypts the plaintext using ring learning-with-errors (RLWE) encryption. |
| // (b/79577340): The parameter t is specified by log_t, but is actually equal to |
| // (1 << log_t) + 1 so that t is odd. This is to allow multiplicative inverses |
| // of powers of 2, which are used to compress and obliviously expand a query |
| // ciphertext. |
| // |
| // The scheme works as follows: |
| // KeyGen(n, modulus q, error distr): |
| // Sample a degree (n-1) polynomial whose coefficients are drawn from the |
| // error distribution (mod q). This is our secret key. Call it s. |
| // |
| // Encrypt(secret key s, plaintext m, modulus q, modulus t, error distr): |
| // 1) Sample a degree (n-1) polynomial whose coefficients are drawn |
| // uniformly from any integer (mod q). Call this polynomial a. |
| // 2) Sample a degree (n-1) polynomial whose coefficients are drawn from |
| // the error distribution (mod q). Call this polynomial e. |
| // 3) Our secret key s and plaintext m are both degree (n-1) polynomials. |
| // For decryption to work, each coefficient of m must be < t. |
| // Compute (a * s + t * e + m) (mod x^n + 1). Call this polynomial b. |
| // 4) The ciphertext is the pair (b, -a). We refer to the pair of |
| // polynomials representing a ciphertext as (c0, c1) = |
| // (a * s + m + e * t, -a). |
| // |
| // Decrypt(secret key s, ciphertext (b, -a), modulus t): |
| // // Decryption when the ciphertext has two components. |
| // Compute and return (b - as) (mod t). Doing out the algebra: |
| // b - as (mod t) |
| // = as + te + m - as (mod t) |
| // = te + m (mod t) |
| // = m |
| // Quoting the paper, "the condition for correct decryption is that the |
| // L_infinity norm of the polynomial [te + m] is smaller than q/2." In |
| // other words, the largest of the values te + m (recall that e is |
| // sampled from a distribution) cannot exceed q/2. |
| // |
| // When the ciphertext has more than two components <c0, c1, ..., cN>, |
| // it can be decrypted by taking the dot product with the vector |
| // <s^0, s^1, ..., s^N> containing powers of the secret key: |
| // te + m = <c0, c1, ..., cN> dot <s^0, s^1, ..., s^N> |
| // = c0 * s^0 + c1 * s^1 + ... + cN * s^N |
| // |
| // Note that the Encrypt() function takes the original plaintext as |
| // a Polynomial<ModularInt>, while the corresponding Decrypt() method |
| // returns a std::vector<typename ModularInt::Int>. The two values will be the |
| // same once the original plaintext is converted out of NTT and Montgomery form. |
| // - The Encrypt() function takes an NTT polynomial so that, if the same |
| // plaintext is to be encrypted repeatedly, the NTT conversion only needs |
| // to be performed once by the caller. |
| // - The Decrypt() function returns a vector of integers because the final |
| // (mod t) step requires taking the polynomial (te + m) out of NTT and |
| // Montgomery form. |
| // It would be straightforward to write a wrapper of Encrypt() that takes |
| // a vector of integers as input, thereby making the plaintext types of the |
| // Encrypt() and Decrypt() functions symmetric. |
| |
| namespace internal { |
| |
| // This functions allows injecting a specific polynomial "a" as the randomness |
| // of the encryption (that is the negation of the c1 component of the |
| // ciphertext) and returns only the resulting c1 component of the ciphertext. |
| // This function is intended for internal use only. |
| template <typename ModularInt> |
| rlwe::StatusOr<Polynomial<ModularInt>> Encrypt( |
| const SymmetricRlweKey<ModularInt>& key, |
| const Polynomial<ModularInt>& plaintext, const Polynomial<ModularInt>& a, |
| SecurePrng* prng) { |
| // Sample the error term from the error distribution. |
| unsigned int num_coeffs = key.Len(); |
| RLWE_ASSIGN_OR_RETURN( |
| std::vector<ModularInt> e_coeffs, |
| SampleFromErrorDistribution<ModularInt>(num_coeffs, key.Variance(), prng, |
| key.ModulusParams())); |
| |
| // Create and return c0. |
| auto e = Polynomial<ModularInt>::ConvertToNtt( |
| std::move(e_coeffs), key.NttParams(), key.ModulusParams()); |
| RLWE_ASSIGN_OR_RETURN(Polynomial<ModularInt> temp, |
| a.Mul(key.Key(), key.ModulusParams())); |
| RLWE_RETURN_IF_ERROR( |
| e.MulInPlace(key.PlaintextModulus(), key.ModulusParams())); |
| RLWE_RETURN_IF_ERROR(temp.AddInPlace(e, key.ModulusParams())); |
| RLWE_RETURN_IF_ERROR(temp.AddInPlace(plaintext, key.ModulusParams())); |
| return temp; |
| } |
| |
| } // namespace internal |
| |
| // Encrypts the supplied plaintext using the given key. Randomness is drawn from |
| // the key's underlying ModulusParams. |
| template <typename ModularInt> |
| rlwe::StatusOr<SymmetricRlweCiphertext<ModularInt>> Encrypt( |
| const SymmetricRlweKey<ModularInt>& key, |
| const Polynomial<ModularInt>& plaintext, |
| const ErrorParams<ModularInt>* error_params, SecurePrng* prng) { |
| // Sample a from the uniform distribution. |
| RLWE_ASSIGN_OR_RETURN(auto a, SamplePolynomialFromPrng<ModularInt>( |
| key.Len(), prng, key.ModulusParams())); |
| |
| // Create c0. |
| RLWE_ASSIGN_OR_RETURN(Polynomial<ModularInt> c0, |
| internal::Encrypt(key, plaintext, a, prng)); |
| |
| // Compute c1 = -a and return the ciphertext. |
| return SymmetricRlweCiphertext<ModularInt>( |
| std::vector<Polynomial<ModularInt>>{ |
| std::move(c0), std::move(a.NegateInPlace(key.ModulusParams()))}, |
| 1, error_params->B_encryption(), key.ModulusParams(), error_params); |
| } |
| |
| // Takes as input the result of decrypting a RLWE plaintext that still contains |
| // the error. Concretely, it contains m + e * t (mod q). This function |
| // eliminates the error and returns the message. For reasons described below, |
| // this operation is more complicated than a simple (mod t). |
| // |
| // The error is drawn from a binomial distribution centered at zero and |
| // multiplied by t, meaning error values are either positive or negative |
| // multiples of t. Since each coefficient of the plaintext is smaller than |
| // t, some coefficients of the quantity m + e * t (which is all that's |
| // left in the vector error_and_message) could be negative. We are using |
| // modular arithmetic, so negative values become large positive values. |
| // |
| // Unfortunately, these negative values cause the naive error elimination |
| // strategy to fail. In theory we could take (m + e * t) mod t to |
| // eliminate the error portion and extract the message. However, consider |
| // a case where the error is negative. Suppose that t=2, m=1, and e=-1 |
| // with a modulus q=7: |
| // |
| // m + e * t (mod q) = |
| // 1 + -1 * 2 (mod 7) = |
| // -1 (mod 7) = |
| // 6 (mod 7) |
| // |
| // When we take 6 (mod t) = 6 (mod 2), we get 0, which is not the original |
| // bit of m. To avoid this problem, we treat negative values as negative |
| // values, not as their equivalents mod q. |
| // |
| // We consider (m + e * t) to be negative whenever it is between q/2 |
| // and q. Recall that, if |m + e * t| is greater than q/2, decryption |
| // fails. |
| // |
| // When the quantity (m + e * t) (mod q) represents a negative number |
| // mod q, we can re-create its non-modular negative form by computing |
| // ((m + e * t) - q). We can then take this value mod t to extract the |
| // correct answer. |
| // |
| // 1. (m + e * t (mod q)) = // in the range [q/2, q) |
| // 2. (m + e * t - q) = // in the range [-q/2, 0) |
| // 3. m (mod t) + e * t (mod t) - q (mod t) = // taken (mod t) |
| // 4. m - (q (mod t)) |
| // |
| // If we subtract q at step 2, we return negative numbers to their |
| // original form. Since we are going to perform a (mod t) operation |
| // anyway, we can subtract q (mod t) at step 2 to get the same result. |
| // Subtracting q (mod t) instead ensures that the quantity at step 2 |
| // does not become negative, which is convenient because we are using |
| // an unsigned integer type. |
| // |
| // Concluding the example from before with the fix: |
| // |
| // m + e * t (mod q) - q (mod t) = |
| // 1 + -1 * 2 (mod 7) - 7 (mod 2) = |
| // -1 (mod 7) - 7 (mod 2) = 6 - 1 = 5 |
| // |
| // 5 (mod t) = 1, which is the original message. |
// Strips the error term from each coefficient of `error_and_message`, which
// holds m + e * t (mod q), and returns the message coefficients reduced mod t.
//
// Each coefficient is exported to a plain integer via `modulus_params_q`. A
// value strictly greater than q / 2 is treated as a negative number mod q and
// has (q mod t) subtracted from it before the final reduction mod t; all other
// values are reduced mod t directly.
//
// The returned vector has the same length as `error_and_message`.
template <typename ModularInt>
std::vector<typename ModularInt::Int> RemoveError(
    const std::vector<ModularInt>& error_and_message,
    const typename ModularInt::Int& q, const typename ModularInt::Int& t,
    const typename ModularInt::Params* modulus_params_q) {
  using Int = typename ModularInt::Int;

  // Correction applied to "negative" coefficients, i.e. q reduced mod t.
  const Int negative_correction = q % t;
  const Int zero = modulus_params_q->Zero();
  // Values above this threshold represent negative numbers mod q.
  const Int half_q = q >> 1;

  std::vector<Int> message(error_and_message.size(), zero);
  std::transform(error_and_message.begin(), error_and_message.end(),
                 message.begin(), [&](const ModularInt& coeff) -> Int {
                   Int value = coeff.ExportInt(modulus_params_q);
                   if (value > half_q) {
                     value = value - negative_correction;
                   }
                   return value % t;
                 });
  return message;
}
| |
| template <typename ModularInt> |
| rlwe::StatusOr<std::vector<typename ModularInt::Int>> Decrypt( |
| const SymmetricRlweKey<ModularInt>& key, |
| const SymmetricRlweCiphertext<ModularInt>& ciphertext) { |
| // Extract the error and message. To do so, take the dot product of the |
| // ciphertext vector <c0, c1, ..., cN> and the vector of the powers of |
| // the key <s^0, s^1, ..., s^N>. |
| |
| // Accumulator variables. |
| Polynomial<ModularInt> error_and_message_ntt(key.Len(), key.ModulusParams()); |
| Polynomial<ModularInt> key_powers = key.Key(); |
| unsigned int ciphertext_len = ciphertext.Len(); |
| |
| for (int i = 0; i < ciphertext_len; i++) { |
| // Extract component i. |
| RLWE_ASSIGN_OR_RETURN(Polynomial<ModularInt> ci, ciphertext.Component(i)); |
| |
| // Lazily increase the exponent of the key. |
| if (i > 1) { |
| RLWE_RETURN_IF_ERROR( |
| key_powers.MulInPlace(key.Key(), key.ModulusParams())); |
| } |
| |
| // Beyond c0, multiply the exponentiated key in. |
| if (i > 0) { |
| RLWE_RETURN_IF_ERROR( |
| ci.MulInPlace(key_powers, ciphertext.ModulusParams())); |
| } |
| |
| RLWE_RETURN_IF_ERROR( |
| error_and_message_ntt.AddInPlace(ci, key.ModulusParams())); |
| } |
| |
| // Invert the NTT process. |
| std::vector<ModularInt> error_and_message = |
| error_and_message_ntt.InverseNtt(key.NttParams(), key.ModulusParams()); |
| |
| // Extract the message. |
| return RemoveError<ModularInt>( |
| error_and_message, key.ModulusParams()->modulus, |
| key.PlaintextModulus().ExportInt(key.PlaintextModulusParams()), |
| key.ModulusParams()); |
| } |
| |
| } // namespace rlwe |
| |
| #endif // RLWE_SYMMETRIC_ENCRYPTION_H_ |