blob: 718ba5b9e5e762c4ef59c406371cf8f0448bddb2 [file] [log] [blame]
// Copyright 2021 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "dns-proxy/resolver.h"
#include <utility>
#include <vector>
#include <base/test/task_environment.h>
#include <base/time/time.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "dns-proxy/ares_client.h"
#include "dns-proxy/doh_curl_client.h"
using testing::_;
using testing::ElementsAreArray;
using testing::Return;
namespace dns_proxy {
namespace {
// Name servers and DoH providers handed to the resolver by the tests below.
const std::vector<std::string> kTestNameServers = {"8.8.8.8"};
const std::vector<std::string> kTestDoHProviders = {
    "https://dns.google/dns-query"};

// Settings forwarded to the base-class constructors of the mocked clients.
constexpr base::TimeDelta kTimeout = base::TimeDelta::FromSeconds(3);
constexpr int32_t kMaxNumRetries{1};
class MockDoHCurlClient : public DoHCurlClient {
public:
MockDoHCurlClient() : DoHCurlClient(kTimeout, kDefaultMaxConcurrentQueries) {}
~MockDoHCurlClient() = default;
MOCK_METHOD4(
Resolve,
bool(const char* msg, int len, const QueryCallback& callback, void* ctx));
MOCK_METHOD1(SetNameServers,
void(const std::vector<std::string>& name_servers));
MOCK_METHOD1(SetDoHProviders,
void(const std::vector<std::string>& doh_providers));
};
class MockAresClient : public AresClient {
public:
MockAresClient()
: AresClient(kTimeout, kMaxNumRetries, kDefaultMaxConcurrentQueries) {}
~MockAresClient() = default;
MOCK_METHOD4(Resolve,
bool(const unsigned char* msg,
size_t len,
const QueryCallback& callback,
void* ctx));
MOCK_METHOD1(SetNameServers,
void(const std::vector<std::string>& name_servers));
};
} // namespace
// Fixture that builds a Resolver wired to a mock ARES client and a mock DoH
// client. Ownership of both mocks is moved into |resolver_|; |ares_client_| and
// |curl_client_| are non-owning aliases kept so tests can set expectations.
class ResolverTest : public testing::Test {
 protected:
  void SetUp() override {
    auto ares_client = std::make_unique<MockAresClient>();
    auto curl_client = std::make_unique<MockDoHCurlClient>();
    // Capture the raw pointers before the unique_ptrs are moved away.
    ares_client_ = ares_client.get();
    curl_client_ = curl_client.get();
    resolver_ = std::make_unique<Resolver>(std::move(ares_client),
                                           std::move(curl_client));
  }

  // Declared before |resolver_| so it is constructed first and destroyed last.
  base::test::TaskEnvironment task_environment_;
  // Non-owning; valid only after SetUp() and while |resolver_| is alive.
  MockAresClient* ares_client_ = nullptr;
  MockDoHCurlClient* curl_client_ = nullptr;
  std::unique_ptr<Resolver> resolver_;
};
// Name servers must be forwarded to both clients, and DoH providers to the DoH
// client, exactly once each.
TEST_F(ResolverTest, SetNameServers) {
  EXPECT_CALL(*ares_client_, SetNameServers(kTestNameServers)).Times(1);
  EXPECT_CALL(*curl_client_, SetNameServers(kTestNameServers)).Times(1);
  EXPECT_CALL(*curl_client_, SetDoHProviders(kTestDoHProviders)).Times(1);

  resolver_->SetNameServers(kTestNameServers);
  resolver_->SetDoHProviders(kTestDoHProviders);
}
// With DoH providers configured, a query goes to the DoH client and never
// reaches the plain-text ARES client.
TEST_F(ResolverTest, Resolve_DNSDoHServers) {
  EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).WillOnce(Return(true));
  EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).Times(0);

  resolver_->SetNameServers(kTestNameServers);
  resolver_->SetDoHProviders(kTestDoHProviders);

  Resolver::SocketFd query_fd(SOCK_STREAM, 0);
  resolver_->Resolve(&query_fd);
}
// Without any DoH provider, a query goes to the plain-text ARES client only.
TEST_F(ResolverTest, Resolve_DNSServers) {
  EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);
  EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).WillOnce(Return(true));

  resolver_->SetNameServers(kTestNameServers);

  Resolver::SocketFd query_fd(SOCK_STREAM, 0);
  resolver_->Resolve(&query_fd);
}
// Even with DoH providers configured, passing |true| as the second argument of
// Resolve() (presumably the fallback flag, per the test name) sends the query
// to ARES instead of the DoH client.
TEST_F(ResolverTest, Resolve_DNSDoHServersFallback) {
  EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);
  EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).WillOnce(Return(true));

  resolver_->SetNameServers(kTestNameServers);
  resolver_->SetDoHProviders(kTestDoHProviders);

  Resolver::SocketFd query_fd(SOCK_STREAM, 0);
  resolver_->Resolve(&query_fd, true);
}
// A curl-level failure (connection error), with always-on DoH left at its
// default, makes the resolver re-issue the query through ARES and not DoH.
TEST_F(ResolverTest, CurlResult_CURLFail) {
  EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);
  EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).WillOnce(Return(true));

  resolver_->SetNameServers(kTestNameServers);
  resolver_->SetDoHProviders(kTestDoHProviders);

  Resolver::SocketFd query_fd(SOCK_STREAM, 0);
  DoHCurlClient::CurlResult result(CURLE_COULDNT_CONNECT, 0 /* http_code */,
                                   0 /* timeout */);
  resolver_->HandleCurlResult(&query_fd, result, nullptr, 0);

  // Let any task posted by HandleCurlResult run before expectations are
  // verified.
  task_environment_.RunUntilIdle();
}
// An HTTP error status (403) from the DoH server, with always-on DoH left at
// its default, makes the resolver re-issue the query through ARES.
TEST_F(ResolverTest, CurlResult_HTTPError) {
  EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);
  EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).WillOnce(Return(true));

  resolver_->SetNameServers(kTestNameServers);
  resolver_->SetDoHProviders(kTestDoHProviders);

  Resolver::SocketFd query_fd(SOCK_STREAM, 0);
  DoHCurlClient::CurlResult result(CURLE_OK, 403 /* http_code */,
                                   0 /* timeout */);
  resolver_->HandleCurlResult(&query_fd, result, nullptr, 0);

  // Let any task posted by HandleCurlResult run before expectations are
  // verified.
  task_environment_.RunUntilIdle();
}
// A successful DoH answer (HTTP 200) must not trigger any further query.
TEST_F(ResolverTest, CurlResult_SuccessNoRetry) {
  EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);
  EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).Times(0);

  resolver_->SetNameServers(kTestNameServers);
  resolver_->SetDoHProviders(kTestDoHProviders);

  // Heap-allocated and handed over through the opaque context pointer.
  // NOTE(review): presumably the resolver deletes it once handled; confirm.
  auto* query_fd = new Resolver::SocketFd(SOCK_STREAM, 0);
  DoHCurlClient::CurlResult result(CURLE_OK, 200 /* http_code */,
                                   0 /* timeout */);
  resolver_->HandleCurlResult(query_fd, result, nullptr, 0);

  task_environment_.RunUntilIdle();
}
// With always-on DoH, neither a curl failure nor an HTTP error may fall back to
// ARES or retry over DoH.
TEST_F(ResolverTest, CurlResult_FailNoRetry) {
  EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);
  EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).Times(0);

  resolver_->SetNameServers(kTestNameServers);
  resolver_->SetDoHProviders(kTestDoHProviders, true /* always_on */);

  // Case 1: curl-level failure.
  auto* query_fd = new Resolver::SocketFd(SOCK_STREAM, 0);
  DoHCurlClient::CurlResult curl_error(CURLE_OUT_OF_MEMORY,
                                       200 /* http_code */, 0 /* timeout */);
  resolver_->HandleCurlResult(query_fd, curl_error, nullptr, 0);
  task_environment_.RunUntilIdle();

  // The previous |query_fd| is expected to have been freed by the resolver at
  // this point, so a fresh one is allocated.
  // Case 2: HTTP error status.
  query_fd = new Resolver::SocketFd(SOCK_STREAM, 0);
  DoHCurlClient::CurlResult http_error(CURLE_OK, 403 /* http_code */,
                                       0 /* timeout */);
  resolver_->HandleCurlResult(query_fd, http_error, nullptr, 0);
  task_environment_.RunUntilIdle();
}
// A retriable HTTP status (429) must not be retried once the query has already
// exhausted its retry budget, and must not fall back to ARES either.
TEST_F(ResolverTest, CurlResult_FailTooManyRetries) {
  EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);
  EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).Times(0);

  resolver_->SetNameServers(kTestNameServers);
  resolver_->SetDoHProviders(kTestDoHProviders);

  auto* query_fd = new Resolver::SocketFd(SOCK_STREAM, 0);
  // Pretend the query has already been retried as often as possible.
  query_fd->num_retries = INT_MAX;
  DoHCurlClient::CurlResult result(CURLE_OK, 429 /* http_code */,
                                   0 /* timeout */);
  resolver_->HandleCurlResult(query_fd, result, nullptr, 0);

  task_environment_.RunUntilIdle();
}
// A successful ARES result must not cause any further ARES or DoH query.
TEST_F(ResolverTest, HandleAresResult_Success) {
  EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).Times(0);
  EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);
  resolver_->SetNameServers(kTestNameServers);

  // Handed over through the opaque context pointer.
  // NOTE(review): presumably freed by the resolver; confirm.
  Resolver::SocketFd* sock_fd = new Resolver::SocketFd(SOCK_DGRAM, 0);
  resolver_->HandleAresResult(static_cast<void*>(sock_fd), ARES_SUCCESS,
                              nullptr, 0);

  // Drain posted tasks, as the HandleCurlResult tests do. Otherwise a query
  // scheduled asynchronously would never run, and the Times(0) expectations
  // above would pass vacuously.
  task_environment_.RunUntilIdle();
}
// A failed ARES result must not cause any further ARES or DoH query.
TEST_F(ResolverTest, HandleAresResult_Fail) {
  EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).Times(0);
  EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);
  resolver_->SetNameServers(kTestNameServers);

  // Handed over through the opaque context pointer.
  // NOTE(review): presumably freed by the resolver; confirm.
  Resolver::SocketFd* sock_fd = new Resolver::SocketFd(SOCK_DGRAM, 0);
  // Use a genuine failure status. Passing ARES_SUCCESS here made this test a
  // duplicate of HandleAresResult_Success and left the failure path untested.
  resolver_->HandleAresResult(static_cast<void*>(sock_fd), ARES_ETIMEOUT,
                              nullptr, 0);

  // Drain posted tasks so that an asynchronously scheduled retry would be
  // caught by the Times(0) expectations above.
  task_environment_.RunUntilIdle();
}
// A well-formed query yields a SERVFAIL response that keeps the query ID and
// question section, sets flags to 0x8002 and zeroes ARCOUNT.
TEST_F(ResolverTest, ConstructServFailResponse_ValidQuery) {
  const char kDnsQuery[] = {
      'J', 'G',                               // ID
      '\x01', ' ',                            // Flags (0x0120)
      '\x00', '\x01',                         // QDCOUNT = 1
      '\x00', '\x00',                         // ANCOUNT = 0
      '\x00', '\x00',                         // NSCOUNT = 0
      '\x00', '\x01',                         // ARCOUNT = 1
      '\x06', 'g', 'o', 'o', 'g', 'l', 'e',   // QNAME label "google"
      '\x03', 'c', 'o', 'm',                  // QNAME label "com"
      '\x00',                                 // QNAME terminator
      '\x00', '\x01',                         // QTYPE = 1
      '\x00', '\x01'};                        // QCLASS = 1
  const char kServFailResponse[] = {
      'J', 'G',                               // ID copied from the query
      '\x80', '\x02',                         // Flags: QR=1, RCODE=2
      '\x00', '\x01',                         // QDCOUNT = 1
      '\x00', '\x00',                         // ANCOUNT = 0
      '\x00', '\x00',                         // NSCOUNT = 0
      '\x00', '\x00',                         // ARCOUNT = 0
      '\x06', 'g', 'o', 'o', 'g', 'l', 'e',   // QNAME label "google"
      '\x03', 'c', 'o', 'm',                  // QNAME label "com"
      '\x00',                                 // QNAME terminator
      '\x00', '\x01',                         // QTYPE = 1
      '\x00', '\x01'};                        // QCLASS = 1

  patchpanel::DnsResponse response =
      resolver_->ConstructServFailResponse(kDnsQuery, sizeof(kDnsQuery));

  const auto* const begin = response.io_buffer()->data();
  const std::vector<char> response_data(begin,
                                        begin + response.io_buffer_size());
  EXPECT_THAT(response_data, ElementsAreArray(kServFailResponse));
}
// A negative query length yields a bare 12-byte SERVFAIL header with a zero ID
// and all section counts zero.
TEST_F(ResolverTest, ConstructServFailResponse_BadLength) {
  const char kDnsQuery[] = {
      'J', 'G',                               // ID
      '\x01', ' ',                            // Flags (0x0120)
      '\x00', '\x01',                         // QDCOUNT = 1
      '\x00', '\x00',                         // ANCOUNT = 0
      '\x00', '\x00',                         // NSCOUNT = 0
      '\x00', '\x01',                         // ARCOUNT = 1
      '\x06', 'g', 'o', 'o', 'g', 'l', 'e',   // QNAME label "google"
      '\x03', 'c', 'o', 'm',                  // QNAME label "com"
      '\x00',                                 // QNAME terminator
      '\x00', '\x01',                         // QTYPE = 1
      '\x00', '\x01'};                        // QCLASS = 1
  const char kServFailResponse[] = {
      '\x00', '\x00',                         // ID
      '\x80', '\x02',                         // Flags: QR=1, RCODE=2
      '\x00', '\x00',                         // QDCOUNT = 0
      '\x00', '\x00',                         // ANCOUNT = 0
      '\x00', '\x00',                         // NSCOUNT = 0
      '\x00', '\x00'};                        // ARCOUNT = 0

  patchpanel::DnsResponse response =
      resolver_->ConstructServFailResponse(kDnsQuery, -1);

  const auto* const begin = response.io_buffer()->data();
  const std::vector<char> response_data(begin,
                                        begin + response.io_buffer_size());
  EXPECT_THAT(response_data, ElementsAreArray(kServFailResponse));
}
// Bytes that do not form a parsable DNS query (no header) yield a bare 12-byte
// SERVFAIL header with a zero ID and all section counts zero.
TEST_F(ResolverTest, ConstructServFailResponse_BadQuery) {
  const char kDnsQuery[] = {
      'g', 'o', 'o', 'g', 'l', 'e',           // Stray label bytes
      '\x03', 'c', 'o', 'm'};
  const char kServFailResponse[] = {
      '\x00', '\x00',                         // ID
      '\x80', '\x02',                         // Flags: QR=1, RCODE=2
      '\x00', '\x00',                         // QDCOUNT = 0
      '\x00', '\x00',                         // ANCOUNT = 0
      '\x00', '\x00',                         // NSCOUNT = 0
      '\x00', '\x00'};                        // ARCOUNT = 0

  patchpanel::DnsResponse response =
      resolver_->ConstructServFailResponse(kDnsQuery, sizeof(kDnsQuery));

  const auto* const begin = response.io_buffer()->data();
  const std::vector<char> response_data(begin,
                                        begin + response.io_buffer_size());
  EXPECT_THAT(response_data, ElementsAreArray(kServFailResponse));
}
} // namespace dns_proxy