// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <utility>

#include "base/bind.h"
#include "base/location.h"
#include "base/logging.h"
#include "base/single_thread_task_runner.h"
#include "base/stl_util.h"
#include "base/threading/thread_task_runner_handle.h"
#include "chrome/browser/local_discovery/service_discovery_client_impl.h"
#include "net/dns/public/dns_protocol.h"
#include "net/dns/record_rdata.h"

namespace local_discovery {

namespace {

// Delay, in milliseconds, that LocalDomainResolverImpl waits for the other
// address record once the first A/AAAA transaction has completed.
// TODO(noamsml): Make this configurable through the LocalDomainResolver
// interface.
constexpr int kLocalDomainSecondAddressTimeoutMs = 100;

// Delay, in seconds, before the first follow-up query posted after
// ServiceWatcherImpl::DiscoverNewServices().
constexpr int kInitialRequeryTimeSeconds = 1;

// Largest delay, in seconds, for which ServiceWatcherImpl::ScheduleQuery()
// still posts another query (time for last requery).
constexpr int kMaxRequeryTimeSeconds = 2;

}  // namespace

// Remembers |mdns_client|; the same pointer is later handed to every watcher
// and resolver created by this object.
ServiceDiscoveryClientImpl::ServiceDiscoveryClientImpl(
    net::MDnsClient* mdns_client)
    : mdns_client_(mdns_client) {}

// Nothing to release explicitly.
ServiceDiscoveryClientImpl::~ServiceDiscoveryClientImpl() = default;

// Builds a ServiceWatcherImpl for |service_type| that reports changes through
// |callback|. The watcher is constructed in the not-started state.
std::unique_ptr<ServiceWatcher>
ServiceDiscoveryClientImpl::CreateServiceWatcher(
    const std::string& service_type,
    const ServiceWatcher::UpdatedCallback& callback) {
  std::unique_ptr<ServiceWatcher> watcher =
      std::make_unique<ServiceWatcherImpl>(service_type, callback,
                                           mdns_client_);
  return watcher;
}

std::unique_ptr<ServiceResolver>
ServiceDiscoveryClientImpl::CreateServiceResolver(
    const std::string& service_name,
    ServiceResolver::ResolveCompleteCallback callback) {
  return std::make_unique<ServiceResolverImpl>(
      service_name, std::move(callback), mdns_client_);
}

std::unique_ptr<LocalDomainResolver>
ServiceDiscoveryClientImpl::CreateLocalDomainResolver(
    const std::string& domain,
    net::AddressFamily address_family,
    LocalDomainResolver::IPAddressCallback callback) {
  return std::make_unique<LocalDomainResolverImpl>(
      domain, address_family, std::move(callback), mdns_client_);
}

// Records the watched |service_type|, the update |callback| and the mDNS
// client. The watcher begins not started and with active refresh disabled.
ServiceWatcherImpl::ServiceWatcherImpl(
    const std::string& service_type,
    const ServiceWatcher::UpdatedCallback& callback,
    net::MDnsClient* mdns_client)
    : service_type_(service_type),
      callback_(callback),
      started_(false),
      actively_refresh_services_(false),
      mdns_client_(mdns_client) {}

// Creates and starts the PTR listener for |service_type_|. |started_| takes
// the listener's Start() result; only on success are cached services read.
// Must not be called on an already started watcher.
void ServiceWatcherImpl::Start() {
  DCHECK(!started_);

  listener_ = mdns_client_->CreateListener(net::dns_protocol::kTypePTR,
                                           service_type_, this);
  started_ = listener_->Start();
  if (!started_)
    return;

  ReadCachedServices();
}

// Nothing to release explicitly.
ServiceWatcherImpl::~ServiceWatcherImpl() = default;

// Issues a network query right away and arms the follow-up query chain,
// starting with a delay of kInitialRequeryTimeSeconds. Requires a started
// watcher.
void ServiceWatcherImpl::DiscoverNewServices() {
  DCHECK(started_);

  const int first_requery_delay_seconds = kInitialRequeryTimeSeconds;
  SendQuery(first_requery_delay_seconds);
}

// Stores the new active-refresh setting and applies it to every service that
// is currently tracked. Requires a started watcher.
void ServiceWatcherImpl::SetActivelyRefreshServices(
    bool actively_refresh_services) {
  DCHECK(started_);
  actively_refresh_services_ = actively_refresh_services;

  for (auto entry = services_.begin(); entry != services_.end(); ++entry) {
    entry->second->SetActiveRefresh(actively_refresh_services);
  }
}

// Starts a cache-only PTR transaction, stored in |transaction_cache_|.
// Requires a started watcher.
void ServiceWatcherImpl::ReadCachedServices() {
  DCHECK(started_);

  const bool query_network = false;
  const bool query_cache = true;
  CreateTransaction(query_network, query_cache, &transaction_cache_);
}

bool ServiceWatcherImpl::CreateTransaction(
    bool network,
    bool cache,
    std::unique_ptr<net::MDnsTransaction>* transaction) {
  int transaction_flags = 0;
  if (network)
    transaction_flags |= net::MDnsTransaction::QUERY_NETWORK;

  if (cache)
    transaction_flags |= net::MDnsTransaction::QUERY_CACHE;

  if (transaction_flags) {
    *transaction = mdns_client_->CreateTransaction(
        net::dns_protocol::kTypePTR, service_type_, transaction_flags,
        base::Bind(&ServiceWatcherImpl::OnTransactionResponse,
                   AsWeakPtr(), transaction));
    return (*transaction)->Start();
  }

  return true;
}

std::string ServiceWatcherImpl::GetServiceType() const {
  return listener_->GetName();
}

// Listener callback. Receives PTR updates for the watched service type as
// well as SRV/TXT updates for the individual services being tracked.
//
// PTR records add or remove a service. SRV records update the per-service
// SRV state. Every SRV/TXT update except a newly added SRV record also queues
// an UPDATE_CHANGED notification for the service.
void ServiceWatcherImpl::OnRecordUpdate(
    net::MDnsListener::UpdateType update,
    const net::RecordParsed* record) {
  DCHECK(started_);

  if (record->type() == net::dns_protocol::kTypePTR) {
    DCHECK(record->name() == GetServiceType());
    const net::PtrRecordRdata* ptr_rdata =
        record->rdata<net::PtrRecordRdata>();

    if (update == net::MDnsListener::RECORD_ADDED) {
      AddService(ptr_rdata->ptrdomain());
    } else if (update == net::MDnsListener::RECORD_REMOVED) {
      RemovePTR(ptr_rdata->ptrdomain());
    } else if (update == net::MDnsListener::RECORD_CHANGED) {
      NOTREACHED();
    }
    return;
  }

  DCHECK(record->type() == net::dns_protocol::kTypeSRV ||
         record->type() == net::dns_protocol::kTypeTXT);
  DCHECK(services_.find(record->name()) != services_.end());

  const bool is_srv = record->type() == net::dns_protocol::kTypeSRV;
  const bool is_added = update == net::MDnsListener::RECORD_ADDED;

  if (is_srv) {
    if (update == net::MDnsListener::RECORD_REMOVED) {
      RemoveSRV(record->name());
    } else if (is_added) {
      AddSRV(record->name());
    }
  }

  // The first sighting of an SRV record is not reported as a change.
  if (!is_srv || !is_added)
    DeferUpdate(UPDATE_CHANGED, record->name());
}

// Listener callback for a cache purge. Intentionally empty for now.
void ServiceWatcherImpl::OnCachePurged() {
  // Not yet implemented.
}

// Response handler for the PTR transactions made by CreateTransaction().
//
// A RESULT_RECORD adds the service named by the PTR record. RESULT_DONE
// releases the transaction held in |*transaction|. Any other result is
// ignored.
void ServiceWatcherImpl::OnTransactionResponse(
    std::unique_ptr<net::MDnsTransaction>* transaction,
    net::MDnsTransaction::Result result,
    const net::RecordParsed* record) {
  DCHECK(started_);

  switch (result) {
    case net::MDnsTransaction::RESULT_RECORD:
      AddService(record->rdata<net::PtrRecordRdata>()->ptrdomain());
      break;
    case net::MDnsTransaction::RESULT_DONE:
      transaction->reset();
      break;
    default:
      // Do nothing for NSEC records. It is an error for hosts to broadcast an
      // NSEC record for PTR records on any name.
      break;
  }
}

// Creates (but does not start) the SRV and TXT listeners for |service_name|,
// both reporting to |watcher|. A new entry starts with the PTR record marked
// present, the SRV record marked absent and no update pending.
ServiceWatcherImpl::ServiceListeners::ServiceListeners(
    const std::string& service_name,
    ServiceWatcherImpl* watcher,
    net::MDnsClient* mdns_client)
    : service_name_(service_name),
      mdns_client_(mdns_client),
      update_pending_(false),
      has_ptr_(true),
      has_srv_(false) {
  srv_listener_ = mdns_client->CreateListener(net::dns_protocol::kTypeSRV,
                                              service_name, watcher);
  txt_listener_ = mdns_client->CreateListener(net::dns_protocol::kTypeTXT,
                                              service_name, watcher);
}

// Nothing to release explicitly.
ServiceWatcherImpl::ServiceListeners::~ServiceListeners() = default;

// Starts the SRV listener and then the TXT listener. Returns true only if
// both start; the TXT listener is not started when the SRV listener fails.
bool ServiceWatcherImpl::ServiceListeners::Start() {
  if (!srv_listener_->Start())
    return false;
  return txt_listener_->Start();
}

// Forwards |active_refresh| to the SRV listener.
//
// When refresh is turned off, any outstanding SRV transaction is dropped.
// When it is turned on and no SRV record has been seen yet, a single-result
// SRV query (cache and network) is started; its result arrives in
// OnSRVRecord().
void ServiceWatcherImpl::ServiceListeners::SetActiveRefresh(
    bool active_refresh) {
  srv_listener_->SetActiveRefresh(active_refresh);

  if (!active_refresh) {
    srv_transaction_.reset();
    return;
  }

  // Active refresh with an SRV record already known: nothing more to do.
  if (has_srv_)
    return;

  DCHECK(has_ptr_);
  const int flags = net::MDnsTransaction::SINGLE_RESULT |
                    net::MDnsTransaction::QUERY_CACHE |
                    net::MDnsTransaction::QUERY_NETWORK;
  srv_transaction_ = mdns_client_->CreateTransaction(
      net::dns_protocol::kTypeSRV, service_name_, flags,
      base::Bind(&ServiceWatcherImpl::ServiceListeners::OnSRVRecord,
                 base::Unretained(this)));
  srv_transaction_->Start();
}

// Result handler for the SRV transaction started by SetActiveRefresh(). The
// SRV record counts as present exactly when |record| is non-null; |result| is
// not inspected.
void ServiceWatcherImpl::ServiceListeners::OnSRVRecord(
    net::MDnsTransaction::Result result,
    const net::RecordParsed* record) {
  const bool record_present = record != nullptr;
  set_has_srv(record_present);
}

// Updates the SRV-present flag. Whatever the new value, any outstanding SRV
// transaction is released.
void ServiceWatcherImpl::ServiceListeners::set_has_srv(bool has_srv) {
  has_srv_ = has_srv;

  srv_transaction_.reset();
}

void ServiceWatcherImpl::AddService(const std::string& service) {
  DCHECK(started_);

  std::unique_ptr<ServiceListeners>& listener = services_[service];
  if (!listener) {
    listener.reset(new ServiceListeners(service, this, mdns_client_));
    bool success = listener->Start();
    DCHECK(success);
    listener->SetActiveRefresh(actively_refresh_services_);
    DeferUpdate(UPDATE_ADDED, service);
  }
  listener->set_has_ptr(true);
}

// Marks the SRV record of |service| as present. Untracked services are
// ignored.
void ServiceWatcherImpl::AddSRV(const std::string& service) {
  DCHECK(started_);

  auto entry = services_.find(service);
  if (entry == services_.end())
    return;

  entry->second->set_has_srv(true);
}

// Posts a task that delivers |update_type| for |service_name| through
// DeliverDeferredUpdate(). Nothing is posted when the service is not tracked
// or already has an update pending.
void ServiceWatcherImpl::DeferUpdate(ServiceWatcher::UpdateType update_type,
                                     const std::string& service_name) {
  auto entry = services_.find(service_name);
  if (entry == services_.end())
    return;

  // At most one queued notification per service at a time.
  if (entry->second->update_pending())
    return;

  entry->second->set_update_pending(true);
  base::ThreadTaskRunnerHandle::Get()->PostTask(
      FROM_HERE,
      base::BindOnce(&ServiceWatcherImpl::DeliverDeferredUpdate, AsWeakPtr(),
                     update_type, service_name));
}

// Task body posted by DeferUpdate(). If |service_name| is still tracked, its
// pending flag is cleared and |callback_| (when set) is run with
// |update_type|. Nothing happens for services that are no longer tracked.
void ServiceWatcherImpl::DeliverDeferredUpdate(
    ServiceWatcher::UpdateType update_type, const std::string& service_name) {
  auto entry = services_.find(service_name);
  if (entry == services_.end())
    return;

  entry->second->set_update_pending(false);
  if (callback_.is_null())
    return;

  callback_.Run(update_type, service_name);
}

// Marks the PTR record of |service| as gone. When neither a PTR nor an SRV
// record remains, the service is dropped and UPDATE_REMOVED is reported
// through |callback_| (when set). Untracked services are ignored.
void ServiceWatcherImpl::RemovePTR(const std::string& service) {
  DCHECK(started_);

  auto entry = services_.find(service);
  if (entry == services_.end())
    return;

  entry->second->set_has_ptr(false);
  if (entry->second->has_ptr_or_srv())
    return;

  services_.erase(entry);
  if (!callback_.is_null())
    callback_.Run(UPDATE_REMOVED, service);
}

// Marks the SRV record of |service| as gone. When neither a PTR nor an SRV
// record remains, the service is dropped and UPDATE_REMOVED is reported
// through |callback_| (when set). Untracked services are ignored.
void ServiceWatcherImpl::RemoveSRV(const std::string& service) {
  DCHECK(started_);

  auto entry = services_.find(service);
  if (entry == services_.end())
    return;

  entry->second->set_has_srv(false);
  if (entry->second->has_ptr_or_srv())
    return;

  services_.erase(entry);
  if (!callback_.is_null())
    callback_.Run(UPDATE_REMOVED, service);
}

// Listener callback for NSEC records; both arguments are ignored.
void ServiceWatcherImpl::OnNsecRecord(const std::string& name,
                                      unsigned rrtype) {
  // Do nothing. It is an error for hosts to broadcast an NSEC record for PTR
  // on any name.
}

// Posts a SendQuery() call to run after |timeout_seconds|, passing it twice
// that delay for the following round. Nothing is posted once
// |timeout_seconds| exceeds kMaxRequeryTimeSeconds, which ends the chain.
void ServiceWatcherImpl::ScheduleQuery(int timeout_seconds) {
  if (timeout_seconds > kMaxRequeryTimeSeconds)
    return;

  const int next_timeout_seconds = timeout_seconds * 2;
  base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
      FROM_HERE,
      base::BindOnce(&ServiceWatcherImpl::SendQuery, AsWeakPtr(),
                     next_timeout_seconds),
      base::TimeDelta::FromSeconds(timeout_seconds));
}

// Starts a network-only PTR transaction in |transaction_network_| and then
// schedules the next query using |next_timeout_seconds| as the delay. The
// transaction's start result is not checked.
void ServiceWatcherImpl::SendQuery(int next_timeout_seconds) {
  const bool query_network = true;
  const bool query_cache = false;
  CreateTransaction(query_network, query_cache, &transaction_network_);

  ScheduleQuery(next_timeout_seconds);
}

// Stores the service to resolve, takes ownership of |callback| and remembers
// the mDNS client. Neither metadata nor address is resolved initially.
ServiceResolverImpl::ServiceResolverImpl(const std::string& service_name,
                                         ResolveCompleteCallback callback,
                                         net::MDnsClient* mdns_client)
    : service_name_(service_name),
      callback_(std::move(callback)),
      metadata_resolved_(false),
      address_resolved_(false),
      mdns_client_(mdns_client) {
}

// Resets the staged result and starts the TXT and SRV transactions. If either
// fails to start, the failure is reported as STATUS_REQUEST_TIMEOUT; the SRV
// transaction is not attempted when the TXT transaction fails.
void ServiceResolverImpl::StartResolving() {
  address_resolved_ = false;
  metadata_resolved_ = false;
  service_staging_ = ServiceDescription();
  service_staging_.service_name = service_name_;

  const bool transactions_started =
      CreateTxtTransaction() && CreateSrvTransaction();
  if (!transactions_started)
    ServiceNotFound(ServiceResolver::STATUS_REQUEST_TIMEOUT);
}

// Nothing to release explicitly.
ServiceResolverImpl::~ServiceResolverImpl() = default;

// Creates and starts a single-result TXT query (cache and network) for
// |service_name_|. Returns the transaction's Start() result.
bool ServiceResolverImpl::CreateTxtTransaction() {
  const int flags = net::MDnsTransaction::SINGLE_RESULT |
                    net::MDnsTransaction::QUERY_CACHE |
                    net::MDnsTransaction::QUERY_NETWORK;
  txt_transaction_ = mdns_client_->CreateTransaction(
      net::dns_protocol::kTypeTXT, service_name_, flags,
      base::Bind(&ServiceResolverImpl::TxtRecordTransactionResponse,
                 AsWeakPtr()));
  return txt_transaction_->Start();
}

// Creates and starts a single-result A query for the host staged from the SRV
// record. Only the cache is consulted (QUERY_NETWORK is not set), and the
// Start() result is not checked.
// TODO(noamsml): quick-resolve for AAAA records. (The original note continued
// "Since A records tend to be in" and was cut off there.)
void ServiceResolverImpl::CreateATransaction() {
  const int flags = net::MDnsTransaction::SINGLE_RESULT |
                    net::MDnsTransaction::QUERY_CACHE;
  a_transaction_ = mdns_client_->CreateTransaction(
      net::dns_protocol::kTypeA, service_staging_.address.host(), flags,
      base::Bind(&ServiceResolverImpl::ARecordTransactionResponse,
                 AsWeakPtr()));
  a_transaction_->Start();
}

// Creates and starts a single-result SRV query (cache and network) for
// |service_name_|. Returns the transaction's Start() result.
bool ServiceResolverImpl::CreateSrvTransaction() {
  const int flags = net::MDnsTransaction::SINGLE_RESULT |
                    net::MDnsTransaction::QUERY_CACHE |
                    net::MDnsTransaction::QUERY_NETWORK;
  srv_transaction_ = mdns_client_->CreateTransaction(
      net::dns_protocol::kTypeSRV, service_name_, flags,
      base::Bind(&ServiceResolverImpl::SrvRecordTransactionResponse,
                 AsWeakPtr()));
  return srv_transaction_->Start();
}

// Returns a copy of the service name given at construction.
std::string ServiceResolverImpl::GetName() const {
  std::string name = service_name_;
  return name;
}

// Handles the SRV query result. The SRV transaction is released first. On a
// record, the host/port and the record's creation time are staged and the A
// lookup is started. Any other status is reported as a failure.
void ServiceResolverImpl::SrvRecordTransactionResponse(
    net::MDnsTransaction::Result status, const net::RecordParsed* record) {
  srv_transaction_.reset();

  if (status == net::MDnsTransaction::RESULT_RECORD) {
    DCHECK(record);
    service_staging_.address = RecordToAddress(record);
    service_staging_.last_seen = record->time_created();
    CreateATransaction();
  } else {
    ServiceNotFound(MDnsStatusToRequestStatus(status));
  }
}

// Handles the TXT query result. The TXT transaction is released first. A
// record supplies the staged metadata; any other status leaves the metadata
// empty. Either way the metadata is then treated as resolved.
void ServiceResolverImpl::TxtRecordTransactionResponse(
    net::MDnsTransaction::Result status, const net::RecordParsed* record) {
  txt_transaction_.reset();

  if (status != net::MDnsTransaction::RESULT_RECORD) {
    service_staging_.metadata.clear();
  } else {
    DCHECK(record);
    service_staging_.metadata = RecordToMetadata(record);
  }

  metadata_resolved_ = true;
  AlertCallbackIfReady();
}

// Handles the A query result. The A transaction is released first. A record
// supplies the staged IP address; any other status stages a default
// constructed address. Either way the address is then treated as resolved.
void ServiceResolverImpl::ARecordTransactionResponse(
    net::MDnsTransaction::Result status, const net::RecordParsed* record) {
  a_transaction_.reset();

  if (status != net::MDnsTransaction::RESULT_RECORD) {
    service_staging_.ip_address = net::IPAddress();
  } else {
    DCHECK(record);
    service_staging_.ip_address = RecordToIPAddress(record);
  }

  address_resolved_ = true;
  AlertCallbackIfReady();
}

// Once both metadata and address are resolved, releases every transaction
// and reports STATUS_SUCCESS with the staged description. |callback_| is
// consumed by the call; nothing is reported when it is null.
void ServiceResolverImpl::AlertCallbackIfReady() {
  const bool ready = metadata_resolved_ && address_resolved_;
  if (!ready)
    return;

  txt_transaction_.reset();
  srv_transaction_.reset();
  a_transaction_.reset();

  if (callback_.is_null())
    return;
  std::move(callback_).Run(STATUS_SUCCESS, service_staging_);
}

// Releases every transaction and reports |status| together with an empty
// ServiceDescription. |callback_| is consumed by the call; nothing is
// reported when it is null.
void ServiceResolverImpl::ServiceNotFound(
    ServiceResolver::RequestStatus status) {
  txt_transaction_.reset();
  srv_transaction_.reset();
  a_transaction_.reset();

  if (callback_.is_null())
    return;
  std::move(callback_).Run(status, ServiceDescription());
}

// Maps a transaction result onto the resolver's status codes:
//   RESULT_RECORD     -> STATUS_SUCCESS
//   RESULT_NO_RESULTS -> STATUS_REQUEST_TIMEOUT
//   RESULT_NSEC       -> STATUS_KNOWN_NONEXISTENT
// Anything else (including RESULT_DONE) is unexpected and falls back to
// STATUS_REQUEST_TIMEOUT.
ServiceResolver::RequestStatus ServiceResolverImpl::MDnsStatusToRequestStatus(
    net::MDnsTransaction::Result status) const {
  if (status == net::MDnsTransaction::RESULT_RECORD)
    return ServiceResolver::STATUS_SUCCESS;
  if (status == net::MDnsTransaction::RESULT_NO_RESULTS)
    return ServiceResolver::STATUS_REQUEST_TIMEOUT;
  if (status == net::MDnsTransaction::RESULT_NSEC)
    return ServiceResolver::STATUS_KNOWN_NONEXISTENT;

  // RESULT_DONE and unknown values pass through to here.
  NOTREACHED();
  return ServiceResolver::STATUS_REQUEST_TIMEOUT;
}

// Returns the text strings of a TXT |record|. The reference refers to data
// held by the record's rdata.
const std::vector<std::string>& ServiceResolverImpl::RecordToMetadata(
    const net::RecordParsed* record) const {
  DCHECK(record->type() == net::dns_protocol::kTypeTXT);
  const net::TxtRecordRdata* txt_rdata = record->rdata<net::TxtRecordRdata>();
  return txt_rdata->texts();
}

// Builds a host/port pair from the target and port of an SRV |record|.
net::HostPortPair ServiceResolverImpl::RecordToAddress(
    const net::RecordParsed* record) const {
  DCHECK(record->type() == net::dns_protocol::kTypeSRV);
  const net::SrvRecordRdata* srv = record->rdata<net::SrvRecordRdata>();
  net::HostPortPair host_port(srv->target(), srv->port());
  return host_port;
}

// Returns the address carried by an A |record|.
net::IPAddress ServiceResolverImpl::RecordToIPAddress(
    const net::RecordParsed* record) const {
  DCHECK(record->type() == net::dns_protocol::kTypeA);
  const net::ARecordRdata* a_rdata = record->rdata<net::ARecordRdata>();
  return a_rdata->address();
}

// Stores the domain and address family to resolve, takes ownership of
// |callback| and remembers the mDNS client. No transaction has finished yet.
LocalDomainResolverImpl::LocalDomainResolverImpl(
    const std::string& domain,
    net::AddressFamily address_family,
    IPAddressCallback callback,
    net::MDnsClient* mdns_client)
    : domain_(domain),
      address_family_(address_family),
      callback_(std::move(callback)),
      transactions_finished_(0),
      mdns_client_(mdns_client) {
}

// Cancels the second-address timeout task, which is bound to an unretained
// pointer to this object.
LocalDomainResolverImpl::~LocalDomainResolverImpl() {
  timeout_callback_.Cancel();
}

// Starts the address lookups selected by |address_family_|: an A query for
// IPv4, an AAAA query for IPv6, and both when the family is unspecified.
// NOTE(review): the result of each transaction's Start() is ignored here;
// confirm whether a failed start should be reported through |callback_|.
void LocalDomainResolverImpl::Start() {
  const bool any_family =
      address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED;

  if (any_family || address_family_ == net::ADDRESS_FAMILY_IPV4) {
    transaction_a_ = CreateTransaction(net::dns_protocol::kTypeA);
    transaction_a_->Start();
  }

  if (any_family || address_family_ == net::ADDRESS_FAMILY_IPV6) {
    transaction_aaaa_ = CreateTransaction(net::dns_protocol::kTypeAAAA);
    transaction_aaaa_->Start();
  }
}

// Creates (without starting) a single-result query of record |type| for
// |domain_| that consults both cache and network. Results are delivered to
// OnTransactionComplete().
std::unique_ptr<net::MDnsTransaction>
LocalDomainResolverImpl::CreateTransaction(uint16_t type) {
  const int flags = net::MDnsTransaction::SINGLE_RESULT |
                    net::MDnsTransaction::QUERY_CACHE |
                    net::MDnsTransaction::QUERY_NETWORK;
  return mdns_client_->CreateTransaction(
      type, domain_, flags,
      base::Bind(&LocalDomainResolverImpl::OnTransactionComplete,
                 base::Unretained(this)));
}

// Completion handler shared by the A and AAAA transactions.
//
// A returned record is stored as the IPv4 or IPv6 address according to its
// type. When both families were requested, the first completion only arms a
// timeout of kLocalDomainSecondAddressTimeoutMs; the addresses are sent when
// the second transaction completes or when that timeout fires. For a single
// family the addresses are sent immediately.
void LocalDomainResolverImpl::OnTransactionComplete(
    net::MDnsTransaction::Result result, const net::RecordParsed* record) {
  ++transactions_finished_;

  if (result == net::MDnsTransaction::RESULT_RECORD) {
    const bool is_ipv4 = record->type() == net::dns_protocol::kTypeA;
    if (is_ipv4) {
      address_ipv4_ = record->rdata<net::ARecordRdata>()->address();
    } else {
      DCHECK_EQ(net::dns_protocol::kTypeAAAA, record->type());
      address_ipv6_ = record->rdata<net::AAAARecordRdata>()->address();
    }
  }

  const bool both_families =
      address_family_ == net::ADDRESS_FAMILY_UNSPECIFIED;

  if (both_families && transactions_finished_ == 1) {
    // Give the other family a short window before reporting.
    timeout_callback_.Reset(
        base::Bind(&LocalDomainResolverImpl::SendResolvedAddresses,
                   base::Unretained(this)));
    base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
        FROM_HERE, timeout_callback_.callback(),
        base::TimeDelta::FromMilliseconds(kLocalDomainSecondAddressTimeoutMs));
    return;
  }

  if (!both_families || transactions_finished_ == 2)
    SendResolvedAddresses();
}

// True when at least one of the stored IPv4/IPv6 addresses is valid.
bool LocalDomainResolverImpl::IsSuccess() {
  if (address_ipv4_.IsValid())
    return true;
  return address_ipv6_.IsValid();
}

// Finishes the resolution: releases both transactions, cancels the
// second-address timeout and reports the success flag together with the
// stored IPv4 and IPv6 addresses through |callback_|.
//
// |callback_| is consumed by the call. It is checked for null first, as
// ServiceResolverImpl does before running its callback, so that a null or
// already-consumed callback is skipped instead of being run.
void LocalDomainResolverImpl::SendResolvedAddresses() {
  transaction_a_.reset();
  transaction_aaaa_.reset();
  timeout_callback_.Cancel();

  if (callback_.is_null())
    return;
  std::move(callback_).Run(IsSuccess(), address_ipv4_, address_ipv6_);
}

}  // namespace local_discovery
