blob: 9d81f9dd3c189ec624afedcb0697fcbdfd1266a0 [file] [log] [blame]
// Copyright (c) 2011 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net/base/dnssec_keyset.h"
#include <cryptohi.h>
#include <cryptoht.h>
#include <keyhi.h>
#include "base/logging.h"
#include "base/memory/scoped_ptr.h"
#include "base/time.h"
#include "crypto/nss_util.h"
#include "net/base/dns_util.h"
namespace {
// DER-encoded AlgorithmIdentifier structures, SEQUENCE { OID, NULL }, that
// select the RSA signature scheme handed to NSS.
// sha1WithRSAEncryption, OID 1.2.840.113549.1.1.5.
const unsigned char kRSAWithSHA1[] = {
  0x30, 0x0d,  // SEQUENCE, 13 bytes
  0x06, 0x09,  // OBJECT IDENTIFIER, 9 bytes
  0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05,
  0x05, 0x00,  // NULL parameters
};
// sha256WithRSAEncryption, OID 1.2.840.113549.1.1.11.
const unsigned char kRSAWithSHA256[] = {
  0x30, 0x0d,  // SEQUENCE, 13 bytes
  0x06, 0x09,  // OBJECT IDENTIFIER, 9 bytes
  0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x0b,
  0x05, 0x00,  // NULL parameters
};
}  // namespace
namespace net {
// Starts with an empty key set. Signature timestamps are enforced by
// CheckSignature() until IgnoreTimestamps() is called.
DNSSECKeySet::DNSSECKeySet() : ignore_timestamps_(false) {}
// No explicit cleanup is performed here.
DNSSECKeySet::~DNSSECKeySet() {}
// Adds the DNSKEY RDATA in |dnskey| to the set of trusted keys. Returns false,
// leaving the set unchanged, if the key cannot be converted into the DER form
// used for verification.
bool DNSSECKeySet::AddKey(const base::StringPiece& dnskey) {
  // ASN1WrapDNSKEY signals an unusable key with an empty string.
  const std::string spki = ASN1WrapDNSKEY(dnskey);
  if (spki.empty())
    return false;
  // |keyids_| and |public_keys_| are kept in lock-step: entry i of each
  // describes the same key.
  keyids_.push_back(DNSKEYToKeyID(dnskey));
  public_keys_.push_back(spki);
  return true;
}
bool DNSSECKeySet::CheckSignature(
    const base::StringPiece& name,
    const base::StringPiece& zone,
    const base::StringPiece& signature,
    uint16 rrtype,
    const std::vector<base::StringPiece>& rrdatas) {
  // Rebuilds the signed data for the RRset formed by |name|, |rrtype| and
  // |rrdatas| (class IN, signer |zone|) and checks |signature| against every
  // trusted key whose key tag matches. Returns true if any of them verifies.
  //
  // signature has this format:
  //   algorithm uint8
  //   labels    uint8
  //   ttl       uint32
  //   expires   uint32
  //   begins    uint32
  //   keyid     uint16
  //
  // followed by the actual signature.
  static const size_t kHeaderLen = 16;
  if (signature.size() < kHeaderLen)
    return false;
  const unsigned char* sigdata =
      reinterpret_cast<const unsigned char*>(signature.data());
  uint8 algorithm = sigdata[0];
  uint32 expires = static_cast<uint32>(sigdata[6]) << 24 |
                   static_cast<uint32>(sigdata[7]) << 16 |
                   static_cast<uint32>(sigdata[8]) << 8 |
                   static_cast<uint32>(sigdata[9]);
  uint32 begins = static_cast<uint32>(sigdata[10]) << 24 |
                  static_cast<uint32>(sigdata[11]) << 16 |
                  static_cast<uint32>(sigdata[12]) << 8 |
                  static_cast<uint32>(sigdata[13]);
  uint16 keyid = static_cast<uint16>(sigdata[14]) << 8 |
                 static_cast<uint16>(sigdata[15]);
  if (!ignore_timestamps_) {
    // NOTE(review): RFC 4034 section 3.1.5 requires these comparisons to use
    // serial-number arithmetic (RFC 1982). This plain unsigned comparison
    // misbehaves near the 32-bit wrap-around (year 2106).
    uint32 now = static_cast<uint32>(base::Time::Now().ToTimeT());
    if (now < begins || now >= expires)
      return false;
  }
  base::StringPiece sig(signature.data() + kHeaderLen,
                        signature.size() - kHeaderLen);
  // You should have RFC 4034, 3.1.8.1 open when reading this code.
  //
  // Lengths are accumulated in size_t. This keeps the total from wrapping
  // around and producing a buffer too small for the memcpy()s below.
  size_t signed_data_len = 0;
  signed_data_len += 2;  // rrtype
  signed_data_len += kHeaderLen;  // (see signature format, above)
  signed_data_len += zone.size();
  for (std::vector<base::StringPiece>::const_iterator
       i = rrdatas.begin(); i != rrdatas.end(); ++i) {
    // RDLENGTH is a 16-bit field. Longer RDATA cannot be encoded faithfully,
    // so reject it rather than silently truncating the length.
    if (i->size() > 0xffff)
      return false;
    signed_data_len += name.size();
    signed_data_len += 2;  // rrtype
    signed_data_len += 2;  // class
    signed_data_len += 4;  // ttl
    signed_data_len += 2;  // RRDATA length
    signed_data_len += i->size();
  }
  scoped_array<unsigned char> signed_data(new unsigned char[signed_data_len]);
  size_t j = 0;
  signed_data[j++] = static_cast<uint8>(rrtype >> 8);
  signed_data[j++] = static_cast<uint8>(rrtype);
  memcpy(&signed_data[j], sigdata, kHeaderLen);
  j += kHeaderLen;
  memcpy(&signed_data[j], zone.data(), zone.size());
  j += zone.size();
  for (std::vector<base::StringPiece>::const_iterator
       i = rrdatas.begin(); i != rrdatas.end(); ++i) {
    memcpy(&signed_data[j], name.data(), name.size());
    j += name.size();
    signed_data[j++] = static_cast<uint8>(rrtype >> 8);
    signed_data[j++] = static_cast<uint8>(rrtype);
    signed_data[j++] = 0;  // CLASS (always IN = 1)
    signed_data[j++] = 1;
    // Copy the TTL from |signature| (bytes 2..5, see format above).
    memcpy(&signed_data[j], signature.data() + 2, sizeof(uint32));
    j += sizeof(uint32);
    const size_t rrdata_len = i->size();  // <= 0xffff, checked above.
    signed_data[j++] = static_cast<uint8>(rrdata_len >> 8);
    signed_data[j++] = static_cast<uint8>(rrdata_len);
    memcpy(&signed_data[j], i->data(), rrdata_len);
    j += rrdata_len;
  }
  DCHECK_EQ(j, signed_data_len);
  base::StringPiece signature_algorithm;
  if (algorithm == kDNSSEC_RSA_SHA1 ||
      algorithm == kDNSSEC_RSA_SHA1_NSEC3) {
    signature_algorithm = base::StringPiece(
        reinterpret_cast<const char*>(kRSAWithSHA1),
        sizeof(kRSAWithSHA1));
  } else if (algorithm == kDNSSEC_RSA_SHA256) {
    signature_algorithm = base::StringPiece(
        reinterpret_cast<const char*>(kRSAWithSHA256),
        sizeof(kRSAWithSHA256));
  } else {
    // Unknown algorithm.
    return false;
  }
  // Check the signature with each trusted key which has a matching keyid.
  DCHECK_EQ(public_keys_.size(), keyids_.size());
  for (size_t i = 0; i < public_keys_.size(); ++i) {
    if (keyids_[i] != keyid)
      continue;
    if (VerifySignature(
            signature_algorithm, sig, public_keys_[i],
            base::StringPiece(reinterpret_cast<const char*>(signed_data.get()),
                              signed_data_len))) {
      return true;
    }
  }
  return false;
}
// static
// Computes the DNSSEC key tag of the DNSKEY RDATA in |dnskey| using the
// algorithm of RFC 4034, Appendix B. The RDATA is summed as big-endian 16-bit
// words, where a trailing odd byte counts as the high half of a final word.
// The carry above bit 16 is folded back in once, and the low 16 bits are
// returned.
uint16 DNSSECKeySet::DNSKEYToKeyID(const base::StringPiece& dnskey) {
  const unsigned char* bytes =
      reinterpret_cast<const unsigned char*>(dnskey.data());
  const size_t size = dnskey.size();
  uint32 sum = 0;
  size_t pos = 0;
  for (; pos + 1 < size; pos += 2)
    sum += (static_cast<uint32>(bytes[pos]) << 8) + bytes[pos + 1];
  if (pos < size)
    sum += static_cast<uint32>(bytes[pos]) << 8;
  sum += (sum >> 16) & 0xffff;
  return static_cast<uint16>(sum);
}
// After this call, CheckSignature() no longer rejects signatures whose
// inception/expiration window does not contain the current time.
void DNSSECKeySet::IgnoreTimestamps() { ignore_timestamps_ = true; }
bool DNSSECKeySet::VerifySignature(
    base::StringPiece signature_algorithm,
    base::StringPiece signature,
    base::StringPiece public_key,
    base::StringPiece signed_data) {
  // Verifies |signature| over |signed_data| with |public_key| (a DER
  // SubjectPublicKeyInfo) under |signature_algorithm| (a DER
  // AlgorithmIdentifier). Returns true only if NSS reports success.
  //
  // This code is largely a copy-and-paste from
  // crypto/signature_verifier_nss.cc. We can't change
  // crypto::SignatureVerifier to always use NSS because we want the ability to
  // be FIPS 140-2 compliant. However, we can't use crypto::SignatureVerifier
  // here because some platforms don't support SHA256 signatures. Therefore, we
  // use NSS directly.
  //
  // Every NSS object allocated here is released on every return path,
  // including the VFY_Begin/VFY_Update failure paths.
  crypto::EnsureNSSInit();

  // NSS takes non-const buffers in SECItem but does not modify them here.
  SECItem spki_der;
  spki_der.type = siBuffer;
  spki_der.data = const_cast<unsigned char*>(
      reinterpret_cast<const unsigned char*>(public_key.data()));
  spki_der.len = public_key.size();
  CERTSubjectPublicKeyInfo* spki =
      SECKEY_DecodeDERSubjectPublicKeyInfo(&spki_der);
  if (!spki)
    return false;
  SECKEYPublicKey* pub_key = SECKEY_ExtractPublicKey(spki);
  SECKEY_DestroySubjectPublicKeyInfo(spki);  // Done with spki.
  if (!pub_key)
    return false;

  PLArenaPool* arena = PORT_NewArena(DER_DEFAULT_CHUNKSIZE);
  if (!arena) {
    SECKEY_DestroyPublicKey(pub_key);
    return false;
  }
  SECItem sig_alg_der;
  sig_alg_der.type = siBuffer;
  sig_alg_der.data = const_cast<unsigned char*>(
      reinterpret_cast<const unsigned char*>(signature_algorithm.data()));
  sig_alg_der.len = signature_algorithm.size();
  SECAlgorithmID sig_alg_id;
  SECStatus rv = SEC_QuickDERDecodeItem(
      arena, &sig_alg_id, SEC_ASN1_GET(SECOID_AlgorithmIDTemplate),
      &sig_alg_der);
  if (rv != SECSuccess) {
    SECKEY_DestroyPublicKey(pub_key);
    PORT_FreeArena(arena, PR_TRUE);
    return false;
  }

  SECItem sig;
  sig.type = siBuffer;
  sig.data = const_cast<unsigned char*>(
      reinterpret_cast<const unsigned char*>(signature.data()));
  sig.len = signature.size();
  SECOidTag hash_alg_tag;
  VFYContext* vfy_context =
      VFY_CreateContextWithAlgorithmID(pub_key, &sig,
                                       &sig_alg_id, &hash_alg_tag,
                                       NULL);
  SECKEY_DestroyPublicKey(pub_key);
  PORT_FreeArena(arena, PR_TRUE);  // Done with sig_alg_id.
  if (!vfy_context) {
    // A corrupted RSA signature could be detected without the data, so
    // VFY_CreateContextWithAlgorithmID may fail with SEC_ERROR_BAD_SIGNATURE
    // (-8182).
    return false;
  }

  rv = VFY_Begin(vfy_context);
  if (rv != SECSuccess) {
    NOTREACHED();
    VFY_DestroyContext(vfy_context, PR_TRUE);
    return false;
  }
  unsigned char* signed_bytes = const_cast<unsigned char*>(
      reinterpret_cast<const unsigned char*>(signed_data.data()));
  rv = VFY_Update(vfy_context, signed_bytes, signed_data.size());
  if (rv != SECSuccess) {
    NOTREACHED();
    VFY_DestroyContext(vfy_context, PR_TRUE);
    return false;
  }
  rv = VFY_End(vfy_context);
  VFY_DestroyContext(vfy_context, PR_TRUE);
  return rv == SECSuccess;
}
// This is a DER-encoded AlgorithmIdentifier for rsaEncryption, used as the
// |algorithm| field of the SubjectPublicKeyInfo built by ASN1WrapDNSKEY.
static const unsigned char kASN1AlgorithmIdentifierRSA[] = {
  0x30,  // SEQUENCE
  0x0d,  // length (13 bytes)
  0x06,  // OBJECT IDENTIFIER
  0x09,  // length (9 bytes)
  // OID 1.2.840.113549.1.1.1 (RSA)
  0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01,
  // NULL of length 0
  0x05, 0x00,
};
// EncodeASN1Length writes a DER length field at |out[*j]|. On entry |*length|
// is the number of bytes still to be written for the current element: the
// length field itself plus the contents that follow it. On return |*j| has
// been advanced past the length field, and |*length| holds the contents
// length that was encoded.
static void EncodeASN1Length(unsigned char* out, unsigned* j,
                             unsigned* length) {
  // Choose the size of the length field. The short form (one byte) covers
  // contents up to 127 bytes and 0x81 plus one byte covers up to 255.
  // Otherwise 0x82 plus two bytes is used, and only the low 16 bits of the
  // contents length are written.
  unsigned field_bytes = 3;
  if (*length - 1 < 128)
    field_bytes = 1;
  else if (*length - 2 < 256)
    field_bytes = 2;
  *length -= field_bytes;
  const unsigned contents = *length;
  unsigned pos = *j;
  if (field_bytes == 1) {
    out[pos++] = static_cast<unsigned char>(contents);
  } else {
    // Long form: 0x80 | number of length bytes, then the length big-endian.
    const unsigned count = field_bytes - 1;
    out[pos++] = static_cast<unsigned char>(0x80 | count);
    for (unsigned k = count; k > 0; --k)
      out[pos++] = static_cast<unsigned char>(contents >> (8 * (k - 1)));
  }
  *j = pos;
}
// AdvanceForASN1Length returns how many bytes a DER length field needs for
// |remaining| bytes of contents. The result is 1 for the short form (< 128),
// 2 for 0x81 plus one byte (< 256), and 3 for 0x82 plus two bytes (< 65536).
// Larger values trigger NOTREACHED() and still report 3.
static unsigned AdvanceForASN1Length(unsigned remaining) {
  if (remaining < 0x80)
    return 1;
  if (remaining <= 0xff)
    return 2;
  if (remaining > 0xffff) {
    NOTREACHED();
  }
  return 3;
}
// ASN1WrapDNSKEY converts the RSA DNSKEY RDATA in |dnskey| into the DER-encoded
// format expected by NSS. To wit:
//   SubjectPublicKeyInfo ::= SEQUENCE {
//     algorithm         AlgorithmIdentifier,
//     subjectPublicKey  BIT STRING }
// where the BIT STRING holds SEQUENCE { modulus INTEGER, exponent INTEGER }.
//
// Byte 3 of |dnskey| is the DNSSEC algorithm number. The key material starts
// at byte 4 in the RFC 3110 layout: an exponent length, the exponent, then
// the modulus. The exponent length is either one byte, or a zero byte
// followed by a big-endian two-byte length.
//
// Returns an empty string when |dnskey| is malformed or uses an unsupported
// algorithm. It is also empty when the exponent is empty or longer than 3
// bytes, or when the modulus is shorter than 64 bytes.
std::string DNSSECKeySet::ASN1WrapDNSKEY(const base::StringPiece& dnskey) {
  const unsigned char* data =
      reinterpret_cast<const unsigned char*>(dnskey.data());
  // The upper bound keeps every DER length below 65536, which the encoders
  // below handle with at most two length bytes.
  if (dnskey.size() < 5 || dnskey.size() > 32767)
    return "";
  // NOTE(review): the Protocol field (data[2]) is not checked here. RFC 4034
  // section 2.1.2 requires it to be 3; confirm whether callers enforce this.
  const uint8 algorithm = data[3];
  if (algorithm != kDNSSEC_RSA_SHA1 &&
      algorithm != kDNSSEC_RSA_SHA1_NSEC3 &&
      algorithm != kDNSSEC_RSA_SHA256) {
    return "";
  }
  unsigned exp_length;
  unsigned exp_offset;
  // First we extract the public exponent.
  if (data[4] == 0) {
    if (dnskey.size() < 7)
      return "";
    exp_length = static_cast<unsigned>(data[5]) << 8 |
                 static_cast<unsigned>(data[6]);
    exp_offset = 7;
  } else {
    exp_length = static_cast<unsigned>(data[4]);
    exp_offset = 5;
  }
  // We refuse to deal with large public exponents. A zero-length exponent is
  // malformed. It would also make the sign-bit test below shift by an
  // out-of-range amount, which is undefined behaviour.
  if (exp_length == 0 || exp_length > 3)
    return "";
  if (dnskey.size() < exp_offset + exp_length)
    return "";
  unsigned exp = 0;
  for (unsigned i = 0; i < exp_length; i++) {
    exp <<= 8;
    exp |= static_cast<unsigned>(data[exp_offset + i]);
  }
  unsigned n_offset = exp_offset + exp_length;
  unsigned n_length = dnskey.size() - n_offset;
  // Anything smaller than 512 bits is too weak to be trusted.
  if (n_length < 64)
    return "";
  // If the MSB of exp is true then we need to prefix a zero byte to stop the
  // ASN.1 encoding from being negative.
  if (exp & (1 << ((8 * exp_length) - 1)))
    exp_length++;
  // Likewise with the modulus
  unsigned n_padding = data[n_offset] & 0x80 ? 1 : 0;
  // We now calculate the length of the full ASN.1 encoded public key. We're
  // working backwards from the end of the structure. Keep in mind that it's:
  //   SEQUENCE
  //     AlgorithmIdentifier
  //     BITSTRING
  //       SEQUENCE
  //         INTEGER
  //         INTEGER
  unsigned length = 0;
  length += exp_length;  // exponent data
  length++;  // we know that |exp_length| < 128
  length++;  // INTEGER tag for exponent
  length += n_length + n_padding;
  length += AdvanceForASN1Length(n_length + n_padding);
  length++;  // INTEGER tag for modulus
  length += AdvanceForASN1Length(length);  // SEQUENCE length
  length++;  // SEQUENCE tag
  length++;  // BITSTRING unused bits
  length += AdvanceForASN1Length(length);  // BITSTRING length
  length++;  // BITSTRING tag
  length += sizeof(kASN1AlgorithmIdentifierRSA);
  length += AdvanceForASN1Length(length);  // SEQUENCE length
  length++;  // SEQUENCE tag
  scoped_array<unsigned char> out(new unsigned char[length]);
  // Now we walk forwards and serialise the ASN.1, undoing the steps above.
  // |length| counts the bytes still to be written.
  unsigned j = 0;
  out[j++] = 0x30;  // SEQUENCE
  length--;
  EncodeASN1Length(out.get(), &j, &length);
  memcpy(&out[j], kASN1AlgorithmIdentifierRSA,
         sizeof(kASN1AlgorithmIdentifierRSA));
  j += sizeof(kASN1AlgorithmIdentifierRSA);
  length -= sizeof(kASN1AlgorithmIdentifierRSA);
  out[j++] = 3;  // BITSTRING tag
  length--;
  EncodeASN1Length(out.get(), &j, &length);
  out[j++] = 0;  // BITSTRING unused bits
  length--;
  out[j++] = 0x30;  // SEQUENCE
  length--;
  EncodeASN1Length(out.get(), &j, &length);
  out[j++] = 2;  // INTEGER
  length--;
  unsigned l = n_length + n_padding;
  if (l < 128) {
    out[j++] = l;
    length--;
  } else if (l < 256) {
    out[j++] = 0x80 | 1;
    out[j++] = l;
    length -= 2;
  } else if (l < 65536) {
    out[j++] = 0x80 | 2;
    out[j++] = l >> 8;
    out[j++] = l;
    length -= 3;
  } else {
    NOTREACHED();
  }
  if (n_padding) {
    out[j++] = 0;
    length--;
  }
  memcpy(&out[j], &data[n_offset], n_length);
  j += n_length;
  length -= n_length;
  out[j++] = 2;  // INTEGER
  length--;
  out[j++] = exp_length;
  length--;
  // Emit the exponent big-endian. When |exp_length| was bumped for the sign
  // byte, the first iteration writes exp >> 24, which is zero. The loop ends
  // when |i| wraps around below zero.
  for (unsigned i = exp_length - 1; i < exp_length; i--) {
    out[j++] = exp >> (8 * i);
    length--;
  }
  DCHECK_EQ(0u, length);
  return std::string(reinterpret_cast<char*>(out.get()), j);
}
} // namespace net