blob: fdbc5f6ac701f5d6a44161078e9a642f5e617bf1 [file] [log] [blame] [edit]
// Copyright 2021 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "dns-proxy/resolver.h"
#include <utility>
#include <base/bind.h>
#include <base/logging.h>
#include <base/memory/ref_counted.h>
#include <base/optional.h>
#include <base/rand_util.h>
#include <base/threading/thread_task_runner_handle.h>
#include <chromeos/patchpanel/dns/dns_protocol.h>
#include <chromeos/patchpanel/dns/dns_query.h>
#include <chromeos/patchpanel/dns/io_buffer.h>
#include <chromeos/patchpanel/net_util.h>
// This using-declaration brings patchpanel's operator<< overloads for socket
// address structures into scope so they can be streamed into log messages.
using patchpanel::operator<<;
namespace dns_proxy {
namespace {

// Value handed to Socket::Listen() for the TCP listening socket (presumably
// the listen(2) backlog — confirm against patchpanel::Socket).
constexpr uint32_t kMaxClientTcpConn = 16;

// Retry delays are scaled by a random factor of roughly
// 1 +/- kRetryDelayJitterMultiplier so that retries of different queries do
// not line up into coordinated spikes.
constexpr float kRetryDelayJitterMultiplier = 0.15;

// Translates a c-ares status code into the error bucket reported to metrics.
// Any status without a dedicated bucket is reported as kOtherClientError.
Metrics::QueryError AresStatusMetric(int status) {
  switch (status) {
    case ARES_SUCCESS:
      return Metrics::QueryError::kNone;

    // The query itself was malformed or unsupported.
    case ARES_EBADFAMILY:
    case ARES_EBADNAME:
    case ARES_EBADQUERY:
    case ARES_EFORMERR:
      return Metrics::QueryError::kBadQuery;

    // Outcomes reported for the lookup.
    case ARES_ENODATA:
      return Metrics::QueryError::kNoData;
    case ARES_ENOTFOUND:
      return Metrics::QueryError::kDomainNotFound;
    case ARES_ENOTIMP:
      return Metrics::QueryError::kNotImplemented;
    case ARES_EREFUSED:
      return Metrics::QueryError::kQueryRefused;
    case ARES_EBADRESP:
    case ARES_ESERVFAIL:
      return Metrics::QueryError::kOtherServerError;

    // Connection-level outcomes.
    case ARES_ECONNREFUSED:
      return Metrics::QueryError::kConnectionRefused;
    case ARES_ETIMEOUT:
      return Metrics::QueryError::kQueryTimeout;

    default:
      return Metrics::QueryError::kOtherClientError;
  }
}

// Translates a CURL result code into the error bucket reported to metrics.
// Any code without a dedicated bucket is reported as kOtherClientError.
Metrics::QueryError CurlCodeMetric(int code) {
  switch (code) {
    case CURLE_OK:
      return Metrics::QueryError::kNone;

    // Problems with the request as built.
    case CURLE_UNSUPPORTED_PROTOCOL:
      return Metrics::QueryError::kUnsupportedProtocol;
    case CURLE_BAD_CONTENT_ENCODING:
    case CURLE_URL_MALFORMAT:
      return Metrics::QueryError::kBadQuery;
    case CURLE_COULDNT_RESOLVE_PROXY:
    case CURLE_COULDNT_RESOLVE_HOST:
      return Metrics::QueryError::kBadHost;

    // Connection setup.
    case CURLE_PEER_FAILED_VERIFICATION:
    case CURLE_SSL_CONNECT_ERROR:
    case CURLE_COULDNT_CONNECT:
      return Metrics::QueryError::kConnectionFailed;
    case CURLE_SSL_CLIENTCERT:
    case CURLE_REMOTE_ACCESS_DENIED:
      return Metrics::QueryError::kConnectionRefused;
    case CURLE_OPERATION_TIMEDOUT:
      return Metrics::QueryError::kQueryTimeout;
    case CURLE_TOO_MANY_REDIRECTS:
      return Metrics::QueryError::kTooManyRedirects;

    // Data transfer.
    case CURLE_AGAIN:
    case CURLE_WRITE_ERROR:
    case CURLE_SEND_ERROR:
      return Metrics::QueryError::kSendError;
    case CURLE_READ_ERROR:
    case CURLE_RECV_ERROR:
      return Metrics::QueryError::kReceiveError;
    case CURLE_GOT_NOTHING:
      return Metrics::QueryError::kNoData;
    case CURLE_RANGE_ERROR:
    case CURLE_WEIRD_SERVER_REPLY:
      return Metrics::QueryError::kOtherServerError;

    default:
      return Metrics::QueryError::kOtherClientError;
  }
}

}  // namespace
// Creates per-query state for a query arriving on |fd| of socket |type|.
// Only datagram sockets receive the peer address from recvfrom(), so only
// they advertise the full capacity of |src| through |socklen|; stream sockets
// start with a zero length.
Resolver::SocketFd::SocketFd(int type, int fd)
    : type(type), fd(fd), num_retries(0) {
  const bool is_stream = (type == SOCK_STREAM);
  if (is_stream) {
    socklen = 0;
  } else {
    socklen = sizeof(src);
  }
}
// Takes ownership of an accepted TCP client socket and watches it for
// readability. Each readable event invokes |callback| with the socket's fd and
// SOCK_STREAM bound as its two arguments.
Resolver::TCPConnection::TCPConnection(
    std::unique_ptr<patchpanel::Socket> sock,
    const base::RepeatingCallback<void(int, int)>& callback)
    : sock(std::move(sock)) {
  // The constructor parameter |sock| is moved-from at this point, so go
  // through the member explicitly.
  const int conn_fd = this->sock->fd();
  watcher = base::FileDescriptorWatcher::WatchReadable(
      conn_fd, base::BindRepeating(callback, conn_fd, SOCK_STREAM));
}
// Builds a resolver that owns its own c-ares client, DoH (CURL) client and
// metrics reporter. |timeout| and |max_concurrent_queries| configure both
// clients. |max_num_retries| is passed to the c-ares client and is also the
// limit for the HTTP 429 retry loop in HandleCurlResult(), which waits about
// |retry_delay| (plus jitter) between attempts. DoH starts out disabled until
// SetDoHProviders() is called.
Resolver::Resolver(base::TimeDelta timeout,
                   base::TimeDelta retry_delay,
                   int max_num_retries,
                   int max_concurrent_queries)
    : always_on_doh_(false),
      doh_enabled_(false),
      retry_delay_(retry_delay),
      max_num_retries_(max_num_retries),
      ares_client_(std::make_unique<AresClient>(timeout,
                                                max_num_retries,
                                                max_concurrent_queries)),
      curl_client_(std::make_unique<DoHCurlClient>(timeout,
                                                   max_concurrent_queries)),
      metrics_(std::make_unique<Metrics>()) {}
// Builds a resolver around caller-supplied clients and metrics reporter
// (presumably for injecting fakes in tests — confirm with callers). DoH starts
// out disabled until SetDoHProviders() is called.
//
// No retry configuration is supplied here, so HTTP 429 retries are disabled
// explicitly. Without these initializers |max_num_retries_| would be left
// unset, unless the header gives it a default, and HandleCurlResult() would
// compare against an indeterminate value. The initializers follow the same
// relative order as the other constructor, so member declaration order is
// respected.
Resolver::Resolver(std::unique_ptr<AresClient> ares_client,
                   std::unique_ptr<DoHCurlClientInterface> curl_client,
                   std::unique_ptr<Metrics> metrics)
    : always_on_doh_(false),
      doh_enabled_(false),
      retry_delay_(base::TimeDelta()),
      max_num_retries_(0),
      ares_client_(std::move(ares_client)),
      curl_client_(std::move(curl_client)),
      metrics_(std::move(metrics)) {}
// Opens a non-blocking TCP socket bound to |addr| and starts listening on it.
// Each time the listening socket becomes readable, OnTCPConnection() is run
// to accept the client. Returns false (and logs) if binding or listening
// fails, in which case no listener is kept.
bool Resolver::ListenTCP(struct sockaddr* addr) {
  const socklen_t addr_len = (addr->sa_family == AF_INET)
                                 ? sizeof(struct sockaddr_in)
                                 : sizeof(struct sockaddr_in6);
  auto listener = std::make_unique<patchpanel::Socket>(
      addr->sa_family, SOCK_STREAM | SOCK_NONBLOCK);

  if (!listener->Bind(addr, addr_len)) {
    LOG(ERROR) << "Cannot bind source socket to " << *addr;
    return false;
  }
  if (!listener->Listen(kMaxClientTcpConn)) {
    LOG(ERROR) << "Cannot listen on " << *addr;
    return false;
  }

  LOG(INFO) << "Accepting connections on " << *addr;
  const int listener_fd = listener->fd();
  tcp_src_watcher_ = base::FileDescriptorWatcher::WatchReadable(
      listener_fd, base::BindRepeating(&Resolver::OnTCPConnection,
                                       weak_factory_.GetWeakPtr()));
  tcp_src_ = std::move(listener);
  return true;
}
// Opens a non-blocking UDP socket bound to |addr|. Readable events are routed
// to OnDNSQuery() with the socket's fd and SOCK_DGRAM. Returns false (and
// logs) if binding fails, in which case no socket is kept.
bool Resolver::ListenUDP(struct sockaddr* addr) {
  const socklen_t addr_len = (addr->sa_family == AF_INET)
                                 ? sizeof(struct sockaddr_in)
                                 : sizeof(struct sockaddr_in6);
  auto listener = std::make_unique<patchpanel::Socket>(
      addr->sa_family, SOCK_DGRAM | SOCK_NONBLOCK);

  if (!listener->Bind(addr, addr_len)) {
    LOG(ERROR) << "Cannot bind source socket to " << *addr;
    return false;
  }

  LOG(INFO) << "Accepting connections on " << *addr;
  const int listener_fd = listener->fd();
  udp_src_watcher_ = base::FileDescriptorWatcher::WatchReadable(
      listener_fd,
      base::BindRepeating(&Resolver::OnDNSQuery, weak_factory_.GetWeakPtr(),
                          listener_fd, SOCK_DGRAM));
  udp_src_ = std::move(listener);
  return true;
}
// Accepts one pending client connection on the TCP listening socket. The
// accepted socket is stored in |tcp_connections_|, keyed by its fd, and each
// readable event on it is routed to OnDNSQuery().
void Resolver::OnTCPConnection() {
  struct sockaddr_storage client_src = {};
  socklen_t sockaddr_len = sizeof(client_src);
  auto client_conn = tcp_src_->Accept(
      reinterpret_cast<struct sockaddr*>(&client_src), &sockaddr_len);
  if (!client_conn) {
    LOG(ERROR) << "Failed to accept TCP connection";
    return;
  }

  // Read the fd before |client_conn| is moved into the TCPConnection. The
  // evaluation order of function arguments is unspecified, so calling
  // client_conn->fd() in the same argument list as std::move(client_conn)
  // could dereference an already moved-from (null) pointer.
  const int client_fd = client_conn->fd();
  tcp_connections_.emplace(
      client_fd,
      new TCPConnection(std::move(client_conn),
                        base::BindRepeating(&Resolver::OnDNSQuery,
                                            weak_factory_.GetWeakPtr())));
}
void Resolver::HandleAresResult(void* ctx,
int status,
unsigned char* msg,
size_t len) {
std::unique_ptr<SocketFd> sock_fd(static_cast<SocketFd*>(ctx));
sock_fd->timer.StopResolve(status == ARES_SUCCESS);
if (metrics_)
metrics_->RecordQueryResult(Metrics::QueryType::kPlainText,
AresStatusMetric(status));
if (status != ARES_SUCCESS) {
LOG(ERROR) << "Failed to do ares lookup: " << ares_strerror(status);
return;
}
ReplyDNS(sock_fd.get(), msg, len);
}
// Completion callback for a DNS-over-HTTPS lookup started in Resolve().
//
// |ctx| is the SocketFd handed to CURL. It is owned here and destroyed (which
// records its metrics) when this function returns, unless ownership is
// explicitly released into a follow-up Resolve() task. Outcomes:
//  - CURL error or unexpected HTTP status: with always-on DoH the query is
//    dropped; otherwise it is re-queued for plain-text resolution.
//  - HTTP 200: the answer in |msg|/|len| is sent back to the client.
//  - HTTP 429: the DoH query is retried after a jittered delay, up to
//    |max_num_retries_| times.
void Resolver::HandleCurlResult(void* ctx,
                                const DoHCurlClient::CurlResult& res,
                                unsigned char* msg,
                                size_t len) {
  std::unique_ptr<SocketFd> sock_fd(static_cast<SocketFd*>(ctx));
  sock_fd->timer.StopResolve(res.curl_code == CURLE_OK);
  if (metrics_) {
    metrics_->RecordQueryResult(Metrics::QueryType::kDnsOverHttps,
                                CurlCodeMetric(res.curl_code), res.http_code);
  }

  // Shared handling for DoH failures. When the query is dropped, |sock_fd| is
  // freed by this function's scope; otherwise ownership moves to the task.
  const auto fall_back_or_drop = [this, &sock_fd]() {
    if (always_on_doh_) {
      // TODO(jasongustaman): Send failure reply with RCODE.
      return;
    }
    base::ThreadTaskRunnerHandle::Get()->PostTask(
        FROM_HERE,
        base::BindOnce(&Resolver::Resolve, weak_factory_.GetWeakPtr(),
                       sock_fd.release(), true /* fallback */));
  };

  if (res.curl_code != CURLE_OK) {
    LOG(ERROR) << "DoH resolution failed: "
               << curl_easy_strerror(res.curl_code);
    fall_back_or_drop();
    return;
  }

  switch (res.http_code) {
    case kHTTPOk:
      ReplyDNS(sock_fd.get(), msg, len);
      return;
    case kHTTPTooManyRequests: {
      if (sock_fd->num_retries >= max_num_retries_) {
        LOG(ERROR) << "Failed to resolve hostname, retried " << max_num_retries_
                   << " tries";
        return;
      }
      // Add jitter to avoid coordinated spikes of retries.
      const double rand_multiplier = 1 - base::RandDouble() * 2;
      const base::TimeDelta retry_delay_jitter =
          (1 + rand_multiplier * kRetryDelayJitterMultiplier) * retry_delay_;
      // Count the attempt before ownership is handed to the delayed task.
      sock_fd->num_retries++;
      base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
          FROM_HERE,
          base::BindOnce(&Resolver::Resolve, weak_factory_.GetWeakPtr(),
                         sock_fd.release(), false /* fallback */),
          retry_delay_jitter);
      return;
    }
    default:
      LOG(ERROR) << "Failed to do curl lookup, HTTP status code "
                 << res.http_code;
      fall_back_or_drop();
      return;
  }
}
// Sends the DNS message |msg| of |len| bytes back to the client of |sock_fd|
// and records the reply timing on its timer.
//
// For UDP the message is sent as-is to the peer address captured by
// recvfrom() in OnDNSQuery(). For TCP it is preceded by the 2-byte
// big-endian length prefix that DNS over TCP requires. Failures are logged;
// nothing is returned.
void Resolver::ReplyDNS(SocketFd* sock_fd, unsigned char* msg, size_t len) {
  sock_fd->timer.StartReply();
  const bool is_udp = sock_fd->type == SOCK_DGRAM;

  // The TCP length prefix is 16 bits wide. A longer message cannot be framed
  // correctly, so refuse it rather than sending a silently truncated length.
  if (!is_udp && len > 0xFFFF) {
    LOG(ERROR) << "DNS reply of " << len << " bytes is too large for TCP";
    sock_fd->timer.StopReply(false);
    return;
  }

  // For TCP, DNS messages have an additional 2-bytes header representing
  // the length of the message. Add the additional header for the reply.
  uint16_t dns_len = htons(static_cast<uint16_t>(len));
  struct iovec iov_out[2];
  iov_out[0].iov_base = &dns_len;
  // For UDP, skip the additional header: an empty first iovec sends nothing.
  iov_out[0].iov_len = is_udp ? 0 : sizeof(dns_len);
  iov_out[1].iov_base = static_cast<void*>(msg);
  iov_out[1].iov_len = len;

  // Zero-initialize so every field not set below (control data, flags and any
  // platform-specific padding) is cleared. Designated initializers on msghdr
  // are not standard C++17.
  struct msghdr hdr = {};
  hdr.msg_iov = iov_out;
  hdr.msg_iovlen = 2;
  if (is_udp) {
    hdr.msg_name = &sock_fd->src;
    hdr.msg_namelen = sock_fd->socklen;
  }

  const bool ok = sendmsg(sock_fd->fd, &hdr, 0) >= 0;
  // Log before stopping the timer: PLOG() reports the current errno, which
  // StopReply() might overwrite.
  if (!ok) {
    PLOG(ERROR) << "sendmsg() " << sock_fd->fd << " failed";
  }
  sock_fd->timer.StopReply(ok);
}
// Forwards the same list of name servers to both the c-ares client and the
// DoH (CURL) client, in that order.
void Resolver::SetNameServers(const std::vector<std::string>& name_servers) {
  const std::vector<std::string>& servers = name_servers;
  ares_client_->SetNameServers(servers);
  curl_client_->SetNameServers(servers);
}
// Configures DNS-over-HTTPS. DoH is attempted only when at least one provider
// is configured. |always_on_doh| additionally disables the plain-text
// (c-ares) path in Resolve(). The provider list is forwarded to the DoH
// client.
void Resolver::SetDoHProviders(const std::vector<std::string>& doh_providers,
                               bool always_on_doh) {
  always_on_doh_ = always_on_doh;
  doh_enabled_ = doh_providers.size() != 0;
  curl_client_->SetDoHProviders(doh_providers);
}
void Resolver::OnDNSQuery(int fd, int type) {
// Initialize SocketFd to carry necessary data. |sock_fd| must be freed when
// it is done being used.
SocketFd* sock_fd = new SocketFd(type, fd);
// Metrics will be recorded automatically when this object is deleted.
sock_fd->timer.set_metrics(metrics_.get());
size_t buf_size;
struct sockaddr* src;
switch (type) {
case SOCK_DGRAM:
sock_fd->msg = sock_fd->buf;
buf_size = kDNSBufSize;
src = reinterpret_cast<struct sockaddr*>(&sock_fd->src);
break;
case SOCK_STREAM:
// For TCP, DNS has an additional 2-bytes header representing the length
// of the query. Move the receiving buffer, so it is 4-bytes aligned.
sock_fd->msg = sock_fd->buf + 2;
buf_size = kDNSBufSize - 2;
src = nullptr;
break;
default:
LOG(DFATAL) << "Unexpected socket type: " << type;
return;
}
sock_fd->timer.StartReceive();
sock_fd->len =
recvfrom(fd, sock_fd->msg, buf_size, 0, src, &sock_fd->socklen);
// Assume success - on failure, the correct value will be recorded.
sock_fd->timer.StopReceive(true);
if (sock_fd->len < 0) {
sock_fd->timer.StopReceive(false);
PLOG(WARNING) << "recvfrom failed";
delete sock_fd;
return;
}
// Handle TCP connection closed.
if (sock_fd->len == 0) {
sock_fd->timer.StopReceive(false);
delete sock_fd;
tcp_connections_.erase(fd);
return;
}
// For TCP, DNS have an additional 2-bytes header representing the length of
// the query. Trim the additional header to be used by CURL or Ares.
if (type == SOCK_STREAM && sock_fd->len > 2) {
sock_fd->msg += 2;
sock_fd->len -= 2;
}
Resolve(sock_fd);
}
// Starts resolving the query held by |sock_fd|. Ownership of |sock_fd|
// follows whichever path accepts it:
//  1. DoH via CURL, unless DoH is disabled or |fallback| is set. The result
//     arrives in HandleCurlResult().
//  2. Plain text via c-ares, unless always-on DoH is configured. The result
//     arrives in HandleAresResult().
//  3. If neither client accepted the query, a SERVFAIL reply is sent here and
//     |sock_fd| is freed here, because no completion callback will run.
void Resolver::Resolve(SocketFd* sock_fd, bool fallback) {
  void* const ctx = reinterpret_cast<void*>(sock_fd);

  if (!fallback && doh_enabled_) {
    sock_fd->timer.StartResolve(true);
    const bool doh_started = curl_client_->Resolve(
        sock_fd->msg, sock_fd->len,
        base::BindRepeating(&Resolver::HandleCurlResult,
                            weak_factory_.GetWeakPtr()),
        ctx);
    if (doh_started)
      return;
    sock_fd->timer.StopResolve(false);
  }

  if (!always_on_doh_) {
    sock_fd->timer.StartResolve();
    const bool plaintext_started = ares_client_->Resolve(
        reinterpret_cast<const unsigned char*>(sock_fd->msg), sock_fd->len,
        base::BindRepeating(&Resolver::HandleAresResult,
                            weak_factory_.GetWeakPtr()),
        ctx);
    if (plaintext_started)
      return;
    sock_fd->timer.StopResolve(false);
  }

  // No query was sent to any name server: complete the request with a
  // SERVFAIL reply and release the per-query state ourselves.
  patchpanel::DnsResponse servfail =
      ConstructServFailResponse(sock_fd->msg, sock_fd->len);
  ReplyDNS(sock_fd,
           reinterpret_cast<unsigned char*>(servfail.io_buffer()->data()),
           servfail.io_buffer_size());
  delete sock_fd;
}
// Builds a DNS response with RCODE SERVFAIL for the query in |msg|/|len|.
// When |len| is in (0, kDNSBufSize] and the bytes parse as a DNS query, the
// response echoes that query's ID and question. Otherwise it uses ID 0 and
// carries no question.
patchpanel::DnsResponse Resolver::ConstructServFailResponse(const char* msg,
                                                            int len) {
  uint16_t query_id = 0;
  base::Optional<patchpanel::DnsQuery> query;

  const bool fits_buffer = len > 0 && len <= dns_proxy::kDNSBufSize;
  if (fits_buffer) {
    // Copy the raw bytes into a buffer the DnsQuery parser can own.
    auto query_buf = base::MakeRefCounted<patchpanel::IOBufferWithSize>(len);
    memcpy(query_buf->data(), msg, len);
    query.emplace(query_buf);
    if (query->Parse(len)) {
      query_id = query->id();
    } else {
      // Do not echo a question that failed to parse.
      query.reset();
    }
  }

  return patchpanel::DnsResponse(
      query_id, false /* is_authoritative */, {} /* answers */,
      {} /* authority_records */, {} /* additional_records */, query,
      patchpanel::dns_protocol::kRcodeSERVFAIL);
}
} // namespace dns_proxy