// Copyright 2021 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "dns-proxy/doh_curl_client.h"

#include <utility>

#include <base/bind.h>
#include <base/strings/string_util.h>
#include <base/threading/thread_task_runner_handle.h>

namespace dns_proxy {
namespace {
// User agent sent with every DoH query (set through CURLOPT_USERAGENT).
// The product token is spelled "KHTML", matching the user agent format used
// by Chrome.
constexpr char kLinuxUserAgent[] =
    "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) "
    "Chrome/7.0.38.09.132 Safari/537.36";
// HTTP headers attached to every DoH query: both the request body and the
// expected response body are binary DNS messages.
constexpr std::array<const char*, 2> kDoHHeaderList{
    {"Accept: application/dns-message",
     "Content-Type: application/dns-message"}};
}  // namespace

// Bundles the outcome of one DoH query that is handed to the query callback:
// |curl_code| is the CURL result code of the transfer, |http_code| is the HTTP
// response code and |retry_delay_ms| is a retry delay in milliseconds.
DoHCurlClient::CurlResult::CurlResult(CURLcode curl_code,
                                      int64_t http_code,
                                      int64_t retry_delay_ms)
    : curl_code(curl_code),
      http_code(http_code),
      retry_delay_ms(retry_delay_ms) {}

// Per-query bookkeeping for a single CURL easy handle.
// |curl| is the easy handle of the query; it is released by the destructor.
// |callback| and |ctx| are stored and used by `RunCallback(...)` to report
// the result. |request_id| groups the concurrent queries that were started
// for the same request.
DoHCurlClient::State::State(CURL* curl,
                            const QueryCallback& callback,
                            void* ctx,
                            int request_id)
    : curl(curl),
      callback(callback),
      ctx(ctx),
      // No HTTP headers yet; the list is built when the query is set up.
      header_list(nullptr),
      request_id(request_id) {}

// Releases the CURL resources held by this state: the easy handle and the
// HTTP header list that was attached to it.
DoHCurlClient::State::~State() {
  // The easy handle is cleaned up first so that the header list is no longer
  // referenced by it when the list is freed.
  curl_easy_cleanup(curl);
  curl_slist_free_all(header_list);
}

// Reports the outcome of this query to the stored callback. The CURL result
// code is taken from |curl_msg|, the HTTP status from |http_code| and the
// response body from the bytes accumulated by `SetResponse(...)`.
void DoHCurlClient::State::RunCallback(CURLMsg* curl_msg, int64_t http_code) {
  // TODO(jasongustaman): Use HTTP 429, Retry-After header value.
  constexpr int64_t kNoRetryDelayMs = 0;
  CurlResult result(curl_msg->data.result, http_code, kNoRetryDelayMs);
  callback.Run(ctx, result, response.data(), response.size());
}

// Appends the |len| bytes starting at |msg| to the accumulated response body.
// An empty chunk is logged as an error and ignored.
void DoHCurlClient::State::SetResponse(char* msg, size_t len) {
  // |len| is unsigned, so zero is the only value that can be rejected here.
  if (len == 0) {
    LOG(ERROR) << "Unexpected length: " << len;
    return;
  }
  char* const chunk_end = msg + len;
  response.insert(response.end(), msg, chunk_end);
}

// Creates a client whose queries time out after |timeout| (truncated to whole
// seconds) and which sends each query to at most |max_concurrent_queries| DoH
// providers at once.
DoHCurlClient::DoHCurlClient(base::TimeDelta timeout,
                             int max_concurrent_queries)
    : timeout_seconds_(timeout.InSeconds()),
      max_concurrent_queries_(max_concurrent_queries) {
  // Bring up libcurl and create the multi handle that drives all queries.
  curl_global_init(CURL_GLOBAL_DEFAULT);
  curlm_ = curl_multi_init();

  // CURL reports every change of a socket's state through
  // `SocketCallback(...)`. |this| is handed back to the callback as its user
  // pointer.
  curl_multi_setopt(curlm_, CURLMOPT_SOCKETDATA, this);
  curl_multi_setopt(curlm_, CURLMOPT_SOCKETFUNCTION,
                    &DoHCurlClient::SocketCallback);

  // CURL reports every change of its timeout through `TimerCallback(...)`.
  // |this| is handed back to the callback as its user pointer.
  curl_multi_setopt(curlm_, CURLMOPT_TIMERDATA, this);
  curl_multi_setopt(curlm_, CURLMOPT_TIMERFUNCTION,
                    &DoHCurlClient::TimerCallback);
}

// Cancels every in-flight query and releases all CURL resources owned by the
// client.
DoHCurlClient::~DoHCurlClient() {
  // Cancel all in-flight queries. This detaches every easy handle from the
  // multi handle and destroys the matching state.
  for (const auto& requests : requests_) {
    CancelRequest(requests.second);
  }
  // The states referenced by |requests_| are gone now; drop the stale
  // pointers.
  requests_.clear();

  // Release the multi handle created in the constructor. This is done here,
  // while all members are still alive, because CURL may invoke the socket
  // callback (which touches the watcher maps) during the cleanup.
  curl_multi_cleanup(curlm_);
  curlm_ = nullptr;

  curl_global_cleanup();
}

// Handles a finished transfer reported by CURL through |curl_msg|.
//
// The query callback is run exactly once per request: either for the first
// query that succeeds with HTTP 200, or, if none succeeds, for the last query
// of the request to finish. A failed query that is not the last one is simply
// dropped so that the remaining queries can still be recognized as "last".
void DoHCurlClient::HandleResult(CURLMsg* curl_msg) {
  CURL* curl = curl_msg->easy_handle;

  // `HandleResult(...)` may be called even after `CancelRequest(...)` is
  // called. This happens if a query is completed while queries are being
  // cancelled. On such case, do nothing.
  auto state_it = states_.find(curl);
  if (state_it == states_.end()) {
    return;
  }
  State* state = state_it->second.get();

  // CURLINFO_RESPONSE_CODE stores its result through a `long*`, which is not
  // the same type as `int64_t*` on every platform.
  long response_code = 0;  // NOLINT(runtime/int)
  curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &response_code);
  const int64_t http_code = response_code;

  // Look the request up without operator[], which would insert an empty entry
  // for an unknown id.
  const int request_id = state->request_id;
  auto request_it = requests_.find(request_id);
  if (request_it == requests_.end()) {
    // The request is not tracked, so there is nothing to report against.
    return;
  }
  auto& pending_states = request_it->second;

  // Run the callback if the current request is the first successful request
  // or the current request is the last request (noted by the number of request
  // with the same |request_id| is 1).
  if (http_code == kHTTPOk || pending_states.size() == 1) {
    state->RunCallback(curl_msg, http_code);
    CancelRequest(request_id);
    return;
  }

  // The query failed but other queries of the same request are still in
  // flight. Stop tracking this one; otherwise the number of pending queries
  // would never drop to 1 and the callback would never run if all of them
  // fail. |curl_msg| and |state| must not be used past this point.
  pending_states.erase(state);
  curl_multi_remove_handle(curlm_, curl);
  states_.erase(state_it);
  // TODO(jasongustaman): Get and save curl metrics.
}

void DoHCurlClient::CheckMultiInfo() {
  CURLMsg* curl_msg = nullptr;
  int msgs_left = 0;
  while ((curl_msg = curl_multi_info_read(curlm_, &msgs_left))) {
    if (curl_msg->msg != CURLMSG_DONE) {
      continue;
    }
    HandleResult(curl_msg);
  }
}

// Invoked when |socket_fd| became readable. Lets CURL process the incoming
// data and then handles any transfer that finished as a result.
void DoHCurlClient::OnFileCanReadWithoutBlocking(curl_socket_t socket_fd) {
  int running_handles;
  const CURLMcode status = curl_multi_socket_action(
      curlm_, socket_fd, CURL_CSELECT_IN, &running_handles);
  if (status == CURLM_OK) {
    CheckMultiInfo();
    return;
  }
  LOG(INFO) << "Failed to read from socket: " << curl_multi_strerror(status);
}

// Invoked when |socket_fd| became writable. Lets CURL continue sending on the
// socket and then handles any transfer that finished as a result.
void DoHCurlClient::OnFileCanWriteWithoutBlocking(curl_socket_t socket_fd) {
  int running_handles;
  const CURLMcode status = curl_multi_socket_action(
      curlm_, socket_fd, CURL_CSELECT_OUT, &running_handles);
  if (status == CURLM_OK) {
    CheckMultiInfo();
    return;
  }
  LOG(INFO) << "Failed to write to socket: " << curl_multi_strerror(status);
}

// Starts watching |socket_fd| for readability unless it is already watched.
// The watcher is bound to a weak pointer, so it does nothing once the client
// is gone.
void DoHCurlClient::AddReadWatcher(curl_socket_t socket_fd) {
  if (base::Contains(read_watchers_, socket_fd)) {
    return;
  }
  auto on_readable =
      base::BindRepeating(&DoHCurlClient::OnFileCanReadWithoutBlocking,
                          weak_factory_.GetWeakPtr(), socket_fd);
  read_watchers_.emplace(
      socket_fd,
      base::FileDescriptorWatcher::WatchReadable(socket_fd, on_readable));
}

// Starts watching |socket_fd| for writability unless it is already watched.
// `OnFileCanWriteWithoutBlocking(...)` is run whenever the socket can be
// written to. The watcher is bound to a weak pointer, so it does nothing once
// the client is gone.
void DoHCurlClient::AddWriteWatcher(curl_socket_t socket_fd) {
  if (!base::Contains(write_watchers_, socket_fd)) {
    write_watchers_.emplace(
        socket_fd,
        // This must be a *writable* watch: the callback tells CURL that the
        // socket is ready for output (CURL_CSELECT_OUT).
        base::FileDescriptorWatcher::WatchWritable(
            socket_fd,
            base::BindRepeating(&DoHCurlClient::OnFileCanWriteWithoutBlocking,
                                weak_factory_.GetWeakPtr(), socket_fd)));
  }
}

// Stops watching |socket_fd| in both directions. Erasing a socket that is not
// watched is a no-op.
void DoHCurlClient::RemoveWatcher(curl_socket_t socket_fd) {
  write_watchers_.erase(socket_fd);
  read_watchers_.erase(socket_fd);
}

// CURL multi socket callback (registered through CURLMOPT_SOCKETFUNCTION).
// Called by CURL whenever the set of events it wants to be told about for
// |socket_fd| changes. |userp| is the DoHCurlClient registered through
// CURLMOPT_SOCKETDATA; |easy| and |socketp| are unused.
//
// |what| describes the complete set of events CURL is now interested in, not
// an addition to the previous set. Watchers for directions that are no longer
// requested are therefore removed; otherwise e.g. a stale write watcher would
// keep firing for as long as the socket stays writable.
//
// Always returns 0.
int DoHCurlClient::SocketCallback(
    CURL* easy, curl_socket_t socket_fd, int what, void* userp, void* socketp) {
  DoHCurlClient* client = static_cast<DoHCurlClient*>(userp);
  switch (what) {
    case CURL_POLL_IN:
      client->write_watchers_.erase(socket_fd);
      client->AddReadWatcher(socket_fd);
      break;
    case CURL_POLL_OUT:
      client->read_watchers_.erase(socket_fd);
      client->AddWriteWatcher(socket_fd);
      break;
    case CURL_POLL_INOUT:
      client->AddReadWatcher(socket_fd);
      client->AddWriteWatcher(socket_fd);
      break;
    case CURL_POLL_REMOVE:
      client->RemoveWatcher(socket_fd);
      break;
    default:
      break;
  }
  return 0;
}

// Tells CURL that its timeout expired and then handles any transfer that
// finished as a result.
void DoHCurlClient::TimeoutCallback() {
  int running_handles;
  curl_multi_socket_action(curlm_, CURL_SOCKET_TIMEOUT, /*ev_bitmask=*/0,
                           &running_handles);
  CheckMultiInfo();
}

// CURL multi timer callback (registered through CURLMOPT_TIMERFUNCTION).
// Called by CURL whenever its timeout changes. |userp| is the DoHCurlClient
// registered through CURLMOPT_TIMERDATA; |multi| is unused.
//
// For a non-negative |timeout_ms|, `TimeoutCallback(...)` is scheduled to run
// on the current thread's task runner after that many milliseconds. A
// negative |timeout_ms| schedules nothing; tasks that were already posted are
// not cancelled and at worst cause a redundant timeout notification.
//
// Always returns 0.
int DoHCurlClient::TimerCallback(CURLM* multi,
                                 int64_t timeout_ms,
                                 void* userp) {
  if (timeout_ms < 0) {
    return 0;
  }

  DoHCurlClient* client = static_cast<DoHCurlClient*>(userp);
  // A zero timeout is posted as well instead of being handled inline: the
  // libcurl documentation warns against calling back into libcurl from within
  // this callback when |timeout_ms| is zero.
  //
  // The task is bound to a weak pointer so that it turns into a no-op if the
  // client is destroyed before the timeout fires.
  base::ThreadTaskRunnerHandle::Get()->PostDelayedTask(
      FROM_HERE,
      base::BindOnce(&DoHCurlClient::TimeoutCallback,
                     client->weak_factory_.GetWeakPtr()),
      base::TimeDelta::FromMilliseconds(timeout_ms));
  return 0;
}

size_t DoHCurlClient::WriteCallback(char* ptr,
                                    size_t size,
                                    size_t nmemb,
                                    void* userdata) {
  State* state = static_cast<State*>(userdata);
  size_t len = size * nmemb;
  state->SetResponse(ptr, len);
  return len;
}

size_t DoHCurlClient::HeaderCallback(void* data,
                                     size_t size,
                                     size_t nitems,
                                     void* userp) {
  State* state = static_cast<State*>(userp);
  size_t len = size * nitems;
  std::string header(static_cast<char*>(data), len);
  state->header.emplace_back(header);
  return len;
}

// Stores |name_servers| as a single comma-separated string, which is the
// form in which it is later handed to CURL (CURLOPT_DNS_SERVERS).
void DoHCurlClient::SetNameServers(
    const std::vector<std::string>& name_servers) {
  const std::string joined_servers = base::JoinString(name_servers, ",");
  name_servers_ = joined_servers;
}

// Replaces the list of DoH providers. `Resolve(...)` sends each query to the
// providers in this list, in order, up to the concurrency limit.
void DoHCurlClient::SetDoHProviders(
    const std::vector<std::string>& doh_providers) {
  doh_providers_ = doh_providers;
}

// Cancels every query in |states|: each easy handle is detached from the
// multi handle and its state is destroyed. The pointers in |states| are
// dangling once this returns.
void DoHCurlClient::CancelRequest(const std::set<State*>& states) {
  for (State* state : states) {
    // Copy the handle out first; |state| is destroyed by the erase below.
    CURL* const handle = state->curl;
    curl_multi_remove_handle(curlm_, handle);
    states_.erase(handle);
  }
}

void DoHCurlClient::CancelRequest(int request_id) {
  auto requests = requests_.find(request_id);
  if (requests == requests_.end()) {
    return;
  }
  // Cancel in-flight queries and delete the state.
  CancelRequest(requests->second);
  requests_.erase(request_id);
}

std::unique_ptr<DoHCurlClient::State> DoHCurlClient::InitCurl(
    const std::string& doh_provider,
    const char* msg,
    int len,
    const QueryCallback& callback,
    void* ctx) {
  CURL* curl;
  curl = curl_easy_init();
  if (!curl) {
    LOG(ERROR) << "Failed to initialize curl";
    return nullptr;
  }

  // Allocate a state for the request.
  std::unique_ptr<State> state =
      std::make_unique<State>(curl, callback, ctx, next_request_id_);

  // Set the target URL which is the DoH provider to query to.
  curl_easy_setopt(curl, CURLOPT_URL, doh_provider.c_str());

  // Set the DNS name servers to resolve the URL(s) / DoH provider(s).
  // This uses ares and will be done asynchronously.
  curl_easy_setopt(curl, CURLOPT_DNS_SERVERS, name_servers_.c_str());

  // Set the HTTP header to the needed DoH header. The stored value needs to
  // be released when query is finished.
  for (int i = 0; i < kDoHHeaderList.size(); i++) {
    state.get()->header_list =
        curl_slist_append(state.get()->header_list, kDoHHeaderList[i]);
  }
  curl_easy_setopt(curl, CURLOPT_HTTPHEADER, state.get()->header_list);

  // Stores the data to be sent through HTTP POST and its length.
  curl_easy_setopt(curl, CURLOPT_POSTFIELDS, msg);
  curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, len);

  // Set the user agent for the query.
  curl_easy_setopt(curl, CURLOPT_USERAGENT, kLinuxUserAgent);

  // Ignore signals SIGPIPE to be sent when the other end of CURL socket is
  // closed.
  curl_easy_setopt(curl, CURLOPT_NOSIGNAL, 0);

  // Set timeout of the query.
  curl_easy_setopt(curl, CURLOPT_TIMEOUT, timeout_seconds_);

  // Set the callback to be called whenever CURL got a response. The data
  // needs to be copied to the write data.
  curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, &DoHCurlClient::WriteCallback);
  curl_easy_setopt(curl, CURLOPT_WRITEDATA, state.get());

  // Handle redirection automatically.
  curl_easy_setopt(curl, CURLOPT_REDIR_PROTOCOLS, 1L);
  curl_easy_setopt(curl, CURLOPT_POSTREDIR, CURL_REDIR_POST_ALL);

  return state;
}

bool DoHCurlClient::Resolve(const char* msg,
                            int len,
                            const QueryCallback& callback,
                            void* ctx) {
  if (name_servers_.empty() || doh_providers_.empty()) {
    LOG(DFATAL) << "DNS and DoH server must not be empty";
    return false;
  }

  std::set<State*> requests;
  int num_concurrent_queries = 0;
  for (const auto& doh_provider : doh_providers_) {
    std::unique_ptr<State> state =
        InitCurl(doh_provider, msg, len, callback, ctx);
    if (!state.get()) {
      continue;
    }
    State* state_ptr = state.get();

    // Create state structure to store required data of each query.
    states_.emplace(state_ptr->curl, std::move(state));
    requests.emplace(state_ptr);

    // Runs the query asynchronously.
    curl_multi_add_handle(curlm_, state_ptr->curl);

    // Queries at most |max_concurrent_queries_| times concurrently.
    num_concurrent_queries++;
    if (num_concurrent_queries >= max_concurrent_queries_) {
      break;
    }
  }

  if (requests.empty()) {
    LOG(ERROR) << "No requests for query";
    return false;
  }

  // Store the concurrent requests and increment |next_request_id_|.
  requests_.emplace(next_request_id_, requests);
  next_request_id_++;
  return true;
}
}  // namespace dns_proxy
