| // Copyright 2021 The Chromium OS Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "dns-proxy/ares_client.h" |
| |
| #include <algorithm> |
| #include <utility> |
| |
| #include <base/bind.h> |
| #include <base/containers/contains.h> |
| #include <base/logging.h> |
| #include <base/strings/string_util.h> |
| #include <base/threading/thread_task_runner_handle.h> |
| |
| namespace dns_proxy { |
| |
| AresClient::State::State(AresClient* client, |
| ares_channel channel, |
| const QueryCallback& callback, |
| void* ctx) |
| : client(client), channel(channel), callback(callback), ctx(ctx) {} |
| |
| AresClient::AresClient(base::TimeDelta timeout, |
| int max_num_retries, |
| int max_concurrent_queries) |
| : timeout_(timeout), |
| max_num_retries_(max_num_retries), |
| max_concurrent_queries_(max_concurrent_queries) { |
| if (ares_library_init(ARES_LIB_INIT_ALL) != ARES_SUCCESS) { |
| LOG(DFATAL) << "Failed to initialize ares library"; |
| } |
| } |
| |
| AresClient::~AresClient() { |
| // Whenever ares_destroy is called, AresCallback will be called with status |
| // equal to ARES_EDESTRUCTION. This callback ensures that the states of the |
| // queries are cleared properly. |
| for (const auto& channel : channels_inflight_) { |
| ares_destroy(channel); |
| } |
| ares_library_cleanup(); |
| } |
| |
| void AresClient::OnFileCanReadWithoutBlocking(ares_channel channel, |
| ares_socket_t socket_fd) { |
| ares_process_fd(channel, socket_fd, ARES_SOCKET_BAD); |
| UpdateWatchers(channel); |
| } |
| |
| void AresClient::OnFileCanWriteWithoutBlocking(ares_channel channel, |
| ares_socket_t socket_fd) { |
| ares_process_fd(channel, ARES_SOCKET_BAD, socket_fd); |
| UpdateWatchers(channel); |
| } |
| |
| void AresClient::UpdateWatchers(ares_channel channel) { |
| ares_socket_t sockets[ARES_GETSOCK_MAXNUM]; |
| int action_bits = ares_getsock(channel, sockets, ARES_GETSOCK_MAXNUM); |
| |
| auto read_watchers = read_watchers_.find(channel); |
| auto write_watchers = write_watchers_.find(channel); |
| if (read_watchers == read_watchers_.end() || |
| write_watchers == write_watchers_.end()) { |
| return; |
| } |
| |
| // Clear the watchers and rebuild it. This is necessary because ares does not |
| // provide a utility to notify unused sockets. |
| read_watchers->second.clear(); |
| write_watchers->second.clear(); |
| for (int i = 0; i < ARES_GETSOCK_MAXNUM; i++) { |
| if (ARES_GETSOCK_READABLE(action_bits, i)) { |
| read_watchers->second.emplace_back( |
| base::FileDescriptorWatcher::WatchReadable( |
| sockets[i], |
| base::BindRepeating(&AresClient::OnFileCanReadWithoutBlocking, |
| weak_factory_.GetWeakPtr(), channel, |
| sockets[i]))); |
| } |
| if (ARES_GETSOCK_WRITABLE(action_bits, i)) { |
| write_watchers->second.emplace_back( |
| base::FileDescriptorWatcher::WatchReadable( |
| sockets[i], |
| base::BindRepeating(&AresClient::OnFileCanWriteWithoutBlocking, |
| weak_factory_.GetWeakPtr(), channel, |
| sockets[i]))); |
| } |
| } |
| } |
| |
| void AresClient::SetNameServers(const std::vector<std::string>& name_servers) { |
| name_servers_ = base::JoinString(name_servers, ","); |
| num_name_servers_ = name_servers.size(); |
| } |
| |
| void AresClient::AresCallback( |
| void* ctx, int status, int timeouts, unsigned char* msg, int len) { |
| State* state = static_cast<State*>(ctx); |
| // The query is cancelled in-flight. Cleanup the state. |
| if (status == ARES_ECANCELLED || status == ARES_EDESTRUCTION) { |
| delete state; |
| return; |
| } |
| |
| auto buf = std::make_unique<unsigned char[]>(len); |
| memcpy(buf.get(), msg, len); |
| // Handle the result outside this function to avoid undefined behaviors. |
| base::ThreadTaskRunnerHandle::Get()->PostTask( |
| FROM_HERE, base::BindOnce(&AresClient::HandleResult, |
| state->client->weak_factory_.GetWeakPtr(), |
| state, status, std::move(buf), len)); |
| } |
| |
| void AresClient::HandleResult(State* state, |
| int status, |
| std::unique_ptr<uint8_t[]> msg, |
| int len) { |
| // Set state as unique pointer to force cleanup, the state must be destroyed |
| // in this function. |
| std::unique_ptr<State> scoped_state(state); |
| |
| // `HandleResult(...)` may be called even after ares channel is destroyed |
| // This happens if a query is completed while queries are being cancelled. |
| // On such case, do nothing, the state will be deleted through unique pointer. |
| if (!base::Contains(channels_inflight_, state->channel)) { |
| return; |
| } |
| |
| // Ares will return 0 if no queries are active on the channel. |
| // |read_fds| and |write_fds| are unused. |
| fd_set read_fds, write_fds; |
| int nfds = ares_fds(state->channel, &read_fds, &write_fds); |
| |
| // Run the callback if the current request is the first successful request |
| // or the current request is the last request. |
| if (status != ARES_SUCCESS && nfds > 0) { |
| return; |
| } |
| state->callback.Run(state->ctx, status, msg.get(), len); |
| msg.reset(); |
| |
| // Cancel other queries and destroy the channel. Whenever ares_destroy is |
| // called, AresCallback will be called with status equal to ARES_EDESTRUCTION. |
| // This callback ensures that the states of the in-flight queries ares cleared |
| // properly. |
| channels_inflight_.erase(state->channel); |
| read_watchers_.erase(state->channel); |
| write_watchers_.erase(state->channel); |
| ares_destroy(state->channel); |
| } |
| |
| void AresClient::ResetTimeout(ares_channel channel) { |
| // Check for timeout if the channel is still available. |
| if (!base::Contains(channels_inflight_, channel)) { |
| return; |
| } |
| ares_process_fd(channel, ARES_SOCKET_BAD, ARES_SOCKET_BAD); |
| |
| struct timeval max_tv, ret_tv; |
| struct timeval* tv; |
| max_tv.tv_sec = timeout_.InMilliseconds() / 1000; |
| max_tv.tv_usec = (timeout_.InMilliseconds() % 1000) * 1000; |
| if ((tv = ares_timeout(channel, &max_tv, &ret_tv)) == NULL) { |
| LOG(ERROR) << "Failed to get timeout"; |
| return; |
| } |
| int timeout_ms = tv->tv_sec * 1000 + tv->tv_usec / 1000; |
| base::ThreadTaskRunnerHandle::Get()->PostDelayedTask( |
| FROM_HERE, |
| base::BindRepeating(&AresClient::ResetTimeout, weak_factory_.GetWeakPtr(), |
| channel), |
| base::TimeDelta::FromMilliseconds(timeout_ms)); |
| } |
| |
| ares_channel AresClient::InitChannel() { |
| struct ares_options options; |
| memset(&options, 0, sizeof(options)); |
| int optmask = 0; |
| |
| // Set option timeout. |
| optmask |= ARES_OPT_TIMEOUTMS; |
| options.timeout = timeout_.InMilliseconds(); |
| |
| // Set maximum number of retries. |
| optmask |= ARES_OPT_TRIES; |
| options.tries = max_num_retries_; |
| |
| // Perform round-robin selection of name servers. This enables Resolve(...) |
| // to resolve using multiple servers concurrently. |
| optmask |= ARES_OPT_ROTATE; |
| |
| ares_channel channel; |
| if (ares_init_options(&channel, &options, optmask) != ARES_SUCCESS) { |
| LOG(ERROR) << "Failed to initialize ares_channel"; |
| ares_destroy(channel); |
| return nullptr; |
| } |
| |
| if (ares_set_servers_csv(channel, name_servers_.c_str()) != ARES_SUCCESS) { |
| LOG(ERROR) << "Failed to set ares name servers"; |
| ares_destroy(channel); |
| return nullptr; |
| } |
| |
| // Start timeout handler. |
| channels_inflight_.emplace(channel); |
| ResetTimeout(channel); |
| return channel; |
| } |
| |
| bool AresClient::Resolve(const unsigned char* msg, |
| size_t len, |
| const QueryCallback& callback, |
| void* ctx) { |
| if (name_servers_.empty()) { |
| LOG(ERROR) << "Name servers must not be empty"; |
| return false; |
| } |
| ares_channel channel = InitChannel(); |
| if (!channel) { |
| return false; |
| } |
| // Query multiple name servers concurrently. Selection of name servers is |
| // done implicitly through round robin selection. This is enabled by ares |
| // option ARES_OPT_ROTATE. |
| for (int i = 0; i < std::min(num_name_servers_, max_concurrent_queries_); |
| i++) { |
| State* state = new State(this, channel, callback, ctx); |
| ares_send(channel, msg, len, &AresClient::AresCallback, state); |
| } |
| |
| // Set up file descriptor watchers. |
| read_watchers_.emplace( |
| channel, |
| std::vector<std::unique_ptr<base::FileDescriptorWatcher::Controller>>()); |
| write_watchers_.emplace( |
| channel, |
| std::vector<std::unique_ptr<base::FileDescriptorWatcher::Controller>>()); |
| UpdateWatchers(channel); |
| return true; |
| } |
| } // namespace dns_proxy |