// blob: 4022261863fb6f7c0d2bd2f463ff61f460c56f24 [file] [log] [blame]
// Copyright 2021 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "dns-proxy/resolver.h"
#include <utility>
#include <vector>
#include <base/test/task_environment.h>
#include <base/time/time.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "dns-proxy/ares_client.h"
#include "dns-proxy/doh_curl_client.h"
using testing::_;
using testing::Return;
namespace dns_proxy {
namespace {
const std::vector<std::string> kTestNameServers{"8.8.8.8"};
const std::vector<std::string> kTestDoHProviders{
"https://dns.google/dns-query"};
constexpr base::TimeDelta kTimeout = base::TimeDelta::FromSeconds(3);
constexpr int32_t kMaxNumRetries = 1;
// gMock stand-in for DoHCurlClient, constructed with the test timeout and the
// default concurrent-query limit. Resolve(), SetNameServers() and
// SetDoHProviders() are mocked so the tests can place expectations on them.
// NOTE(review): the mocks only intercept calls made through the base class if
// the matching DoHCurlClient methods are virtual with these exact signatures;
// that is not visible from this file.
class MockDoHCurlClient : public DoHCurlClient {
 public:
  MockDoHCurlClient()
      : DoHCurlClient(kTimeout, kDefaultMaxConcurrentQueries) {}
  ~MockDoHCurlClient() = default;

  // Configuration setters.
  MOCK_METHOD1(SetDoHProviders,
               void(const std::vector<std::string>& doh_providers));
  MOCK_METHOD1(SetNameServers,
               void(const std::vector<std::string>& name_servers));

  // Query entry point.
  MOCK_METHOD4(Resolve,
               bool(const char* msg,
                    int len,
                    const QueryCallback& callback,
                    void* ctx));
};
// gMock stand-in for AresClient, constructed with the test timeout, the test
// retry count and the default concurrent-query limit. Resolve() and
// SetNameServers() are mocked so the tests can place expectations on them.
// NOTE(review): the mocks only intercept calls made through the base class if
// the matching AresClient methods are virtual with these exact signatures;
// that is not visible from this file.
class MockAresClient : public AresClient {
 public:
  MockAresClient()
      : AresClient(kTimeout, kMaxNumRetries, kDefaultMaxConcurrentQueries) {}
  ~MockAresClient() = default;

  // Configuration setter.
  MOCK_METHOD1(SetNameServers,
               void(const std::vector<std::string>& name_servers));

  // Query entry point.
  MOCK_METHOD4(Resolve,
               bool(const unsigned char* msg,
                    size_t len,
                    const QueryCallback& callback,
                    void* ctx));
};
} // namespace
// Test fixture that builds a Resolver on top of a MockAresClient and a
// MockDoHCurlClient.
//
// Both mocks are created in SetUp() and handed to the Resolver constructor as
// std::unique_ptr; the fixture keeps only non-owning pointers so tests can set
// expectations on them.
class ResolverTest : public testing::Test {
 protected:
  void SetUp() override {
    auto ares_client = std::make_unique<MockAresClient>();
    auto curl_client = std::make_unique<MockDoHCurlClient>();
    // Capture the raw pointers before the unique_ptrs are moved from.
    ares_client_ = ares_client.get();
    curl_client_ = curl_client.get();
    resolver_ = std::make_unique<Resolver>(std::move(ares_client),
                                           std::move(curl_client));
  }

  // Declared first so it is destroyed last, after |resolver_|.
  base::test::TaskEnvironment task_environment_;
  // Non-owning pointers to the mocks handed to |resolver_|; null until
  // SetUp() has run.
  MockAresClient* ares_client_ = nullptr;
  MockDoHCurlClient* curl_client_ = nullptr;
  std::unique_ptr<Resolver> resolver_;
};
// Setting name servers and DoH providers on the resolver is expected to reach
// the underlying clients exactly once each: name servers go to both clients,
// DoH providers go to the DoH client.
TEST_F(ResolverTest, SetNameServers) {
  // Expectations on the DoH client.
  EXPECT_CALL(*curl_client_, SetDoHProviders(kTestDoHProviders)).Times(1);
  EXPECT_CALL(*curl_client_, SetNameServers(kTestNameServers)).Times(1);
  // Expectation on the ares client.
  EXPECT_CALL(*ares_client_, SetNameServers(kTestNameServers)).Times(1);

  resolver_->SetNameServers(kTestNameServers);
  resolver_->SetDoHProviders(kTestDoHProviders);
}
// With both name servers and DoH providers configured, Resolve() is expected
// to issue one query through the DoH client and none through the ares client.
TEST_F(ResolverTest, Resolve_DNSDoHServers) {
  EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).WillOnce(Return(true));
  EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).Times(0);

  resolver_->SetNameServers(kTestNameServers);
  resolver_->SetDoHProviders(kTestDoHProviders);

  Resolver::SocketFd query_fd(SOCK_STREAM, 0);
  resolver_->Resolve(&query_fd);
}
// With only name servers configured (no DoH providers), Resolve() is expected
// to issue one query through the ares client and none through the DoH client.
TEST_F(ResolverTest, Resolve_DNSServers) {
  EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);
  EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).Times(1);

  resolver_->SetNameServers(kTestNameServers);

  Resolver::SocketFd query_fd(SOCK_STREAM, 0);
  resolver_->Resolve(&query_fd);
}
// Even with DoH providers configured, calling Resolve() with its second
// argument set to true is expected to query the ares client once and leave the
// DoH client untouched.
TEST_F(ResolverTest, Resolve_DNSDoHServersFallback) {
  EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);
  EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).Times(1);

  resolver_->SetNameServers(kTestNameServers);
  resolver_->SetDoHProviders(kTestDoHProviders);

  Resolver::SocketFd query_fd(SOCK_STREAM, 0);
  // NOTE(review): `true` is presumably the fallback flag (per the test name);
  // confirm against the Resolver::Resolve() declaration.
  resolver_->Resolve(&query_fd, true);
}
// A curl-level failure (CURLE_COULDNT_CONNECT) reported to HandleCurlResult()
// is expected to trigger exactly one query through the ares client and no new
// query through the DoH client.
TEST_F(ResolverTest, CurlResult_CURLFail) {
  EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);
  EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).Times(1);

  resolver_->SetNameServers(kTestNameServers);
  resolver_->SetDoHProviders(kTestDoHProviders);

  Resolver::SocketFd query_fd(SOCK_STREAM, 0);
  void* ctx = static_cast<void*>(&query_fd);
  DoHCurlClient::CurlResult connect_failure(
      CURLE_COULDNT_CONNECT, 0 /* http_code */, 0 /* timeout */);
  resolver_->HandleCurlResult(ctx, connect_failure, nullptr, 0);

  // Run any pending tasks before the expectations are verified.
  task_environment_.RunUntilIdle();
}
// A successful curl transfer carrying HTTP status 403 is expected to trigger
// exactly one query through the ares client and no new query through the DoH
// client.
TEST_F(ResolverTest, CurlResult_HTTPError) {
  EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);
  EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).Times(1);

  resolver_->SetNameServers(kTestNameServers);
  resolver_->SetDoHProviders(kTestDoHProviders);

  Resolver::SocketFd query_fd(SOCK_STREAM, 0);
  void* ctx = static_cast<void*>(&query_fd);
  DoHCurlClient::CurlResult http_forbidden(CURLE_OK, 403 /* http_code */,
                                           0 /* timeout */);
  resolver_->HandleCurlResult(ctx, http_forbidden, nullptr, 0);

  // Run any pending tasks before the expectations are verified.
  task_environment_.RunUntilIdle();
}
// A successful DoH result (CURLE_OK with HTTP 200) must not cause any further
// query on either client.
TEST_F(ResolverTest, CurlResult_SuccessNoRetry) {
  EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);
  EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).Times(0);

  resolver_->SetNameServers(kTestNameServers);
  resolver_->SetDoHProviders(kTestDoHProviders);

  // Heap-allocated and never deleted here; presumably the resolver takes
  // ownership on this path - confirm against Resolver::HandleCurlResult().
  auto* query_fd = new Resolver::SocketFd(SOCK_STREAM, 0);
  void* ctx = static_cast<void*>(query_fd);
  DoHCurlClient::CurlResult http_ok(CURLE_OK, 200 /* http_code */,
                                    0 /* timeout */);
  resolver_->HandleCurlResult(ctx, http_ok, nullptr, 0);

  // Run any pending tasks before the expectations are verified.
  task_environment_.RunUntilIdle();
}
// With DoH providers configured as always-on, neither a curl-level error nor
// an HTTP error may cause a further query on either client.
TEST_F(ResolverTest, CurlResult_FailNoRetry) {
  EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);
  EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).Times(0);

  resolver_->SetNameServers(kTestNameServers);
  resolver_->SetDoHProviders(kTestDoHProviders, true /* always_on */);

  // Case 1: curl reports an error (CURLE_OUT_OF_MEMORY).
  auto* first_fd = new Resolver::SocketFd(SOCK_STREAM, 0);
  DoHCurlClient::CurlResult curl_error(CURLE_OUT_OF_MEMORY,
                                       200 /* http_code */, 0 /* timeout */);
  resolver_->HandleCurlResult(static_cast<void*>(first_fd), curl_error,
                              nullptr, 0);
  task_environment_.RunUntilIdle();

  // Case 2: curl succeeds but the HTTP status is 403.
  // |first_fd| should be freed by now, so a fresh SocketFd is allocated.
  auto* second_fd = new Resolver::SocketFd(SOCK_STREAM, 0);
  DoHCurlClient::CurlResult http_error(CURLE_OK, 403 /* http_code */,
                                       0 /* timeout */);
  resolver_->HandleCurlResult(static_cast<void*>(second_fd), http_error,
                              nullptr, 0);
  task_environment_.RunUntilIdle();
}
// An HTTP 429 result for a query whose retry counter is already at INT_MAX
// must not cause any further query on either client.
TEST_F(ResolverTest, CurlResult_FailTooManyRetries) {
  EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);
  EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).Times(0);

  resolver_->SetNameServers(kTestNameServers);
  resolver_->SetDoHProviders(kTestDoHProviders);

  // Heap-allocated and never deleted here; presumably the resolver takes
  // ownership on this path - confirm against Resolver::HandleCurlResult().
  auto* query_fd = new Resolver::SocketFd(SOCK_STREAM, 0);
  // Push the retry counter to its maximum value.
  query_fd->num_retries = INT_MAX;
  void* ctx = static_cast<void*>(query_fd);
  DoHCurlClient::CurlResult too_many_requests(CURLE_OK, 429 /* http_code */,
                                              0 /* timeout */);
  resolver_->HandleCurlResult(ctx, too_many_requests, nullptr, 0);

  // Run any pending tasks before the expectations are verified.
  task_environment_.RunUntilIdle();
}
// A successful ares result (ARES_SUCCESS) must not cause any further query on
// either client.
TEST_F(ResolverTest, HandleAresResult_Success) {
  EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);
  EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).Times(0);

  resolver_->SetNameServers(kTestNameServers);

  // Heap-allocated and never deleted here; presumably the resolver takes
  // ownership - confirm against Resolver::HandleAresResult().
  auto* query_fd = new Resolver::SocketFd(SOCK_DGRAM, 0);
  void* ctx = static_cast<void*>(query_fd);
  resolver_->HandleAresResult(ctx, ARES_SUCCESS, nullptr, 0);
}
// A failing ares result must not cause any further query on either client.
// The status passed is a real c-ares error code (ARES_ESERVFAIL) so that the
// failure path of HandleAresResult() is the one being exercised.
TEST_F(ResolverTest, HandleAresResult_Fail) {
  EXPECT_CALL(*ares_client_, Resolve(_, _, _, _)).Times(0);
  EXPECT_CALL(*curl_client_, Resolve(_, _, _, _)).Times(0);

  resolver_->SetNameServers(kTestNameServers);

  // Heap-allocated and never deleted here; presumably the resolver takes
  // ownership - confirm against Resolver::HandleAresResult().
  Resolver::SocketFd* sock_fd = new Resolver::SocketFd(SOCK_DGRAM, 0);
  resolver_->HandleAresResult(static_cast<void*>(sock_fd), ARES_ESERVFAIL,
                              nullptr, 0);
}
} // namespace dns_proxy