| // Copyright 2023 The ChromiumOS Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "net-base/dns_client.h" |
| |
| #include <fcntl.h> |
| #include <netdb.h> |
| |
| #include <string> |
| #include <string_view> |
| #include <utility> |
| |
| #include <base/files/file_util.h> |
| #include <base/files/scoped_file.h> |
| #include <base/logging.h> |
| #include <base/time/time.h> |
| #include <base/test/task_environment.h> |
| #include <gmock/gmock.h> |
| #include <gtest/gtest.h> |
| |
| #include "net-base/ares_interface.h" |
| |
| namespace net_base { |
| namespace { |
| |
| using ::testing::_; |
| using ::testing::AllOf; |
| using ::testing::Field; |
| using ::testing::Ge; |
| using ::testing::Pointee; |
| using ::testing::StrEq; |
| using ::testing::StrictMock; |
| |
| using Error = DNSClient::Error; |
| using Result = DNSClient::Result; |
| |
| class FakeAres : public AresInterface { |
| public: |
| ~FakeAres(); |
| |
| int init_options(ares_channel* channelptr, |
| struct ares_options* options, |
| int optmask) override; |
| |
| void destroy(ares_channel channel) override; |
| |
| void set_local_dev(ares_channel channel, const char* local_dev_name) override; |
| |
| void gethostbyname(ares_channel channel, |
| const char* name, |
| int family, |
| ares_host_callback callback, |
| void* arg) override; |
| |
| struct timeval* timeout(ares_channel channel, |
| struct timeval* maxtv, |
| struct timeval* tv) override; |
| |
| int getsock(ares_channel channel, |
| ares_socket_t* socks, |
| int numsocks) override; |
| |
| void process_fd(ares_channel channel, |
| ares_socket_t read_fd, |
| ares_socket_t write_fd) override; |
| |
| int set_servers_csv(ares_channel channel, const char* servers) override; |
| |
| // The client of FakeAres will get the event that socket is readable or |
| // writable. |
| void TriggerReadReady(); |
| void TriggerWriteReady(); |
| |
| // The next process_fd() call will invoke the callback with the given |
| // parameters. |
| void InvokeCallbackOnNextProcessFD(int status, std::vector<IPAddress> addrs); |
| |
| private: |
| struct GethostbynameParams { |
| int family = 0; |
| void* arg = nullptr; |
| ares_host_callback callback = nullptr; |
| }; |
| |
| struct CallbackResult { |
| int status = 0; |
| std::vector<IPAddress> addrs; |
| }; |
| |
| ares_channel CreateChannel() { |
| CHECK_EQ(channel_, nullptr) << "Channel has been created"; |
| // Note that the value doesn't important here. It will only be used as an |
| // identifier so it only needs to be unique. |
| channel_ = this; |
| return reinterpret_cast<ares_channel>(channel_); |
| } |
| |
| void CheckChannel(ares_channel channel) { |
| CHECK_EQ(channel, channel_) << "Input channel does not match"; |
| } |
| |
| void VerifyReadFD(int fd); |
| void BlockWriteFD(); |
| |
| void* channel_ = nullptr; |
| |
| std::unique_ptr<GethostbynameParams> gethostbyname_params_; |
| |
| base::ScopedFD read_fd_local_; |
| base::ScopedFD read_fd_remote_; |
| base::ScopedFD write_fd_local_; |
| base::ScopedFD write_fd_remote_; |
| |
| std::optional<CallbackResult> callback_result_; |
| }; |
| |
| FakeAres::~FakeAres() { |
| CHECK_EQ(channel_, nullptr) |
| << "Channel is not nullptr, perhaps no call to ares_destroy()?"; |
| } |
| |
| int FakeAres::init_options(ares_channel* channelptr, |
| struct ares_options* options, |
| int optmask) { |
| LOG(INFO) << __func__ << ": " << channelptr; |
| *channelptr = CreateChannel(); |
| return ARES_SUCCESS; |
| } |
| |
| void FakeAres::destroy(ares_channel channel) { |
| CheckChannel(channel); |
| if (gethostbyname_params_) { |
| gethostbyname_params_->callback(gethostbyname_params_->arg, |
| ARES_EDESTRUCTION, |
| /*timeouts=*/0, /*hostent=*/nullptr); |
| } |
| |
| channel_ = nullptr; |
| gethostbyname_params_ = nullptr; |
| read_fd_local_.reset(); |
| read_fd_remote_.reset(); |
| write_fd_local_.reset(); |
| write_fd_remote_.reset(); |
| } |
| |
| void FakeAres::set_local_dev(ares_channel channel, const char* local_dev_name) { |
| CheckChannel(channel); |
| } |
| |
| void FakeAres::gethostbyname(ares_channel channel, |
| const char* name, |
| int family, |
| ares_host_callback callback, |
| void* arg) { |
| CheckChannel(channel); |
| CHECK_EQ(gethostbyname_params_, nullptr) << "Callback has been set"; |
| gethostbyname_params_ = std::make_unique<GethostbynameParams>(); |
| gethostbyname_params_->family = family; |
| gethostbyname_params_->callback = callback; |
| gethostbyname_params_->arg = arg; |
| |
| int fds[2]; |
| CHECK_EQ(pipe2(fds, 0), 0); |
| read_fd_local_.reset(fds[0]); |
| read_fd_remote_.reset(fds[1]); |
| CHECK_EQ(pipe2(fds, 0), 0); |
| write_fd_local_.reset(fds[1]); |
| write_fd_remote_.reset(fds[0]); |
| |
| // Block the write fd at first so that the client won't get a ready event at |
| // the beginning. |
| BlockWriteFD(); |
| } |
| |
| struct timeval* FakeAres::timeout(ares_channel channel, |
| struct timeval* maxtv, |
| struct timeval* tv) { |
| return maxtv; |
| } |
| |
| int FakeAres::getsock(ares_channel channel, |
| ares_socket_t* socks, |
| int numsocks) { |
| CheckChannel(channel); |
| CHECK_GE(numsocks, 2); |
| socks[0] = read_fd_local_.get(); |
| socks[1] = write_fd_local_.get(); |
| // (1 << 0): socket 0 is readable; |
| // (1 << (ARES_GETSOCK_MAXNUM + 1)): socket 1 is writable; |
| return (1 << 0) | (1 << (ARES_GETSOCK_MAXNUM + 1)); |
| } |
| |
| void FakeAres::process_fd(ares_channel channel, |
| ares_socket_t read_fd, |
| ares_socket_t write_fd) { |
| CheckChannel(channel); |
| if (read_fd != ARES_SOCKET_BAD) { |
| CHECK_EQ(read_fd, read_fd_local_.get()); |
| VerifyReadFD(read_fd); |
| } |
| if (write_fd != ARES_SOCKET_BAD) { |
| CHECK_EQ(write_fd, write_fd_local_.get()); |
| BlockWriteFD(); |
| } |
| if (!callback_result_.has_value()) { |
| return; |
| } |
| |
| CHECK(gethostbyname_params_); |
| std::vector<std::vector<uint8_t>> addrs_in_bytes; |
| std::vector<char*> ptrs_to_addrs_in_bytes; |
| for (const auto& ip : callback_result_->addrs) { |
| addrs_in_bytes.push_back(ip.ToBytes()); |
| ptrs_to_addrs_in_bytes.push_back( |
| reinterpret_cast<char*>(addrs_in_bytes.back().data())); |
| } |
| ptrs_to_addrs_in_bytes.push_back(nullptr); |
| |
| struct hostent ent; |
| |
| // Not using these fields in the implementation now, just ignore them. |
| ent.h_name = nullptr; |
| ent.h_aliases = nullptr; |
| |
| ent.h_addrtype = gethostbyname_params_->family; |
| ent.h_length = gethostbyname_params_->family == AF_INET ? 4 : 16; |
| ent.h_addr_list = ptrs_to_addrs_in_bytes.data(); |
| |
| gethostbyname_params_->callback(gethostbyname_params_->arg, |
| callback_result_->status, |
| /*timeouts=*/0, &ent); |
| } |
| |
| int FakeAres::set_servers_csv(ares_channel channel, const char* servers) { |
| CheckChannel(channel); |
| return ARES_SUCCESS; |
| } |
| |
// The payload used to trigger and to verify the read fd behavior. Its actual
// content is arbitrary.
constexpr std::string_view kFDContent{"0"};
| |
| // Triggers the read ready event by sending some content on the pipe. |
| void FakeAres::TriggerReadReady() { |
| CHECK(read_fd_remote_.is_valid()) << "Read fd is not ready"; |
| CHECK(base::WriteFileDescriptor(read_fd_remote_.get(), kFDContent)); |
| } |
| |
| void FakeAres::VerifyReadFD(int fd) { |
| char buf[kFDContent.size() + 1]; |
| CHECK(base::ReadFromFD(fd, buf, kFDContent.size())) |
| << "Failed to read from fd"; |
| CHECK_EQ(std::string_view(buf, kFDContent.size()), kFDContent); |
| } |
| |
// The buffer size applied to the write pipe. Note that 4096 is the minimum
// buffer size of a pipe.
constexpr int kPipeBufferSize{4096};
| |
| // Triggers the write ready event by consuming the content in the pipe so that |
| // it's no longer blocking. |
| void FakeAres::TriggerWriteReady() { |
| CHECK(write_fd_remote_.is_valid()) << "Write fd is not ready"; |
| static char buf[kPipeBufferSize]; |
| CHECK(base::ReadFromFD(write_fd_remote_.get(), buf, kPipeBufferSize)); |
| } |
| |
| void FakeAres::BlockWriteFD() { |
| CHECK_EQ(kPipeBufferSize, |
| fcntl(write_fd_remote_.get(), F_SETPIPE_SZ, kPipeBufferSize)); |
| CHECK(base::WriteFileDescriptor(write_fd_local_.get(), |
| std::string(kPipeBufferSize, 'a'))); |
| } |
| |
| void FakeAres::InvokeCallbackOnNextProcessFD(int status, |
| std::vector<IPAddress> addrs) { |
| callback_result_ = CallbackResult{status, std::move(addrs)}; |
| } |
| |
| class DNSClientTest : public testing::Test { |
| protected: |
| DNSClientTest() = default; |
| ~DNSClientTest() = default; |
| |
| MOCK_METHOD(void, Callback, (const Result&), ()); |
| MOCK_METHOD(void, CallbackWithDuration, (base::TimeDelta, const Result&), ()); |
| DNSClient::Callback GetCallback() { |
| return base::BindOnce(&DNSClientTest::Callback, base::Unretained(this)); |
| } |
| DNSClient::CallbackWithDuration GetCallbackWithDuration() { |
| return base::BindOnce(&DNSClientTest::CallbackWithDuration, |
| base::Unretained(this)); |
| } |
| |
| base::test::TaskEnvironment task_env_{ |
| base::test::TaskEnvironment::TimeSource::MOCK_TIME, |
| base::test::TaskEnvironment::MainThreadType::IO}; |
| }; |
| |
| TEST_F(DNSClientTest, IPv4WriteReadAndReturnSuccess) { |
| FakeAres fake_ares; |
| const auto addrs = {IPAddress::CreateFromString("192.168.1.1").value(), |
| IPAddress::CreateFromString("192.168.1.2").value()}; |
| |
| auto client = |
| DNSClientFactory().Resolve(IPFamily::kIPv4, "test-url", GetCallback(), |
| /*options=*/{}, &fake_ares); |
| |
| fake_ares.TriggerWriteReady(); |
| task_env_.RunUntilIdle(); |
| |
| fake_ares.TriggerReadReady(); |
| fake_ares.InvokeCallbackOnNextProcessFD(ARES_SUCCESS, addrs); |
| EXPECT_CALL(*this, Callback(Result(addrs))); |
| task_env_.RunUntilIdle(); |
| } |
| |
| TEST_F(DNSClientTest, IPv6WriteReadAndReturnSuccess) { |
| FakeAres fake_ares; |
| const auto addrs = {IPAddress::CreateFromString("fd00::1").value(), |
| IPAddress::CreateFromString("fd00::2").value()}; |
| |
| auto client = |
| DNSClientFactory().Resolve(IPFamily::kIPv6, "test-url", GetCallback(), |
| /*options=*/{}, &fake_ares); |
| |
| fake_ares.TriggerWriteReady(); |
| task_env_.RunUntilIdle(); |
| |
| fake_ares.TriggerReadReady(); |
| fake_ares.InvokeCallbackOnNextProcessFD(ARES_SUCCESS, addrs); |
| EXPECT_CALL(*this, Callback(Result(addrs))); |
| task_env_.RunUntilIdle(); |
| } |
| |
| TEST_F(DNSClientTest, MultipleWriteReadAndReturnSuccess) { |
| FakeAres fake_ares; |
| const auto addrs = {IPAddress::CreateFromString("192.168.1.1").value(), |
| IPAddress::CreateFromString("192.168.1.2").value()}; |
| |
| auto client = |
| DNSClientFactory().Resolve(IPFamily::kIPv4, "test-url", GetCallback(), |
| /*options=*/{}, &fake_ares); |
| |
| fake_ares.TriggerWriteReady(); |
| fake_ares.TriggerReadReady(); |
| task_env_.RunUntilIdle(); |
| |
| fake_ares.TriggerWriteReady(); |
| fake_ares.TriggerReadReady(); |
| task_env_.RunUntilIdle(); |
| |
| fake_ares.TriggerReadReady(); |
| fake_ares.InvokeCallbackOnNextProcessFD(ARES_SUCCESS, addrs); |
| EXPECT_CALL(*this, Callback(Result(addrs))); |
| task_env_.RunUntilIdle(); |
| } |
| |
| TEST_F(DNSClientTest, WriteReadAndReturnError) { |
| FakeAres fake_ares; |
| |
| auto client = |
| DNSClientFactory().Resolve(IPFamily::kIPv4, "test-url", GetCallback(), |
| /*options=*/{}, &fake_ares); |
| |
| fake_ares.TriggerWriteReady(); |
| fake_ares.TriggerReadReady(); |
| fake_ares.InvokeCallbackOnNextProcessFD(ARES_ENODATA, {}); |
| EXPECT_CALL(*this, Callback(Result(base::unexpected(Error::kNoData)))); |
| task_env_.RunUntilIdle(); |
| } |
| |
| TEST_F(DNSClientTest, WriteAndTimeout) { |
| FakeAres fake_ares; |
| |
| DNSClient::Options opts = { |
| .timeout = base::Seconds(1), |
| }; |
| auto client = DNSClientFactory().Resolve(IPFamily::kIPv4, "test-url", |
| GetCallback(), opts, &fake_ares); |
| |
| fake_ares.TriggerWriteReady(); |
| task_env_.RunUntilIdle(); |
| |
| EXPECT_CALL(*this, Callback(Result(base::unexpected(Error::kTimedOut)))); |
| task_env_.FastForwardBy(base::Seconds(2)); |
| } |
| |
| TEST_F(DNSClientTest, WriteAndDestroyObject) { |
| FakeAres fake_ares; |
| |
| auto client = |
| DNSClientFactory().Resolve(IPFamily::kIPv4, "test-url", GetCallback(), |
| /*options=*/{}, &fake_ares); |
| |
| fake_ares.TriggerWriteReady(); |
| task_env_.RunUntilIdle(); |
| |
| // No callback should be invoked in this case. |
| EXPECT_CALL(*this, Callback).Times(0); |
| client.reset(); |
| task_env_.FastForwardUntilNoTasksRemain(); |
| } |
| |
| // Only need to mock several functions. |
| class MockAres : public FakeAres { |
| public: |
| MockAres() { |
| ON_CALL(*this, init_options) |
| .WillByDefault([this](ares_channel* channelptr, |
| struct ares_options* options, int optmask) { |
| return this->FakeAres::init_options(channelptr, options, optmask); |
| }); |
| ON_CALL(*this, set_local_dev) |
| .WillByDefault( |
| [this](ares_channel channel, const char* local_dev_name) { |
| return this->FakeAres::set_local_dev(channel, local_dev_name); |
| }); |
| ON_CALL(*this, set_servers_csv) |
| .WillByDefault([this](ares_channel channel, const char* servers) { |
| return this->FakeAres::set_servers_csv(channel, servers); |
| }); |
| } |
| |
| MOCK_METHOD(int, |
| init_options, |
| (ares_channel * channelptr, |
| struct ares_options* options, |
| int optmask), |
| (override)); |
| MOCK_METHOD(void, |
| set_local_dev, |
| (ares_channel channel, const char* local_dev_name), |
| (override)); |
| MOCK_METHOD(int, |
| set_servers_csv, |
| (ares_channel channel, const char* servers), |
| (override)); |
| }; |
| |
| TEST_F(DNSClientTest, ResolveWithOptions) { |
| StrictMock<MockAres> mock_ares; |
| |
| DNSClient::Options test_opts = { |
| .number_of_tries = 5, |
| .per_query_initial_timeout = base::Seconds(10), |
| .interface = "wlan0", |
| .name_server = IPAddress::CreateFromString("1.2.3.4").value(), |
| }; |
| |
| EXPECT_CALL(mock_ares, |
| init_options(_, |
| Pointee(AllOf(Field(&ares_options::tries, 5), |
| Field(&ares_options::timeout, 10000))), |
| ARES_OPT_TIMEOUTMS | ARES_OPT_TRIES)); |
| EXPECT_CALL(mock_ares, set_local_dev(_, StrEq("wlan0"))); |
| EXPECT_CALL(mock_ares, set_servers_csv(_, StrEq("1.2.3.4"))); |
| |
| auto client = DNSClientFactory().Resolve( |
| IPFamily::kIPv4, "test-url", GetCallback(), test_opts, &mock_ares); |
| } |
| |
| TEST_F(DNSClientTest, ResolveWithoutOptions) { |
| StrictMock<MockAres> mock_ares; |
| |
| EXPECT_CALL(mock_ares, |
| init_options(_, |
| Pointee(AllOf(Field(&ares_options::tries, 0), |
| Field(&ares_options::timeout, 0))), |
| /*opt_masks=*/0)); |
| |
| auto client = DNSClientFactory().Resolve( |
| IPFamily::kIPv4, "test-url", GetCallback(), /*options=*/{}, &mock_ares); |
| } |
| |
| TEST_F(DNSClientTest, ReturnSuccessWithDuration) { |
| FakeAres fake_ares; |
| const auto addrs = {IPAddress::CreateFromString("192.168.1.1").value()}; |
| |
| auto client = DNSClientFactory().Resolve(IPFamily::kIPv4, "test-url", |
| GetCallbackWithDuration(), |
| /*options=*/{}, &fake_ares); |
| |
| task_env_.FastForwardBy(base::Milliseconds(150)); |
| fake_ares.TriggerWriteReady(); |
| task_env_.RunUntilIdle(); |
| |
| fake_ares.TriggerReadReady(); |
| fake_ares.InvokeCallbackOnNextProcessFD(ARES_SUCCESS, addrs); |
| EXPECT_CALL(*this, |
| CallbackWithDuration(Ge(base::Milliseconds(150)), Result(addrs))); |
| task_env_.RunUntilIdle(); |
| } |
| |
| TEST_F(DNSClientTest, ReturnErrorWithDuration) { |
| FakeAres fake_ares; |
| |
| auto client = DNSClientFactory().Resolve(IPFamily::kIPv4, "test-url", |
| GetCallbackWithDuration(), |
| /*options=*/{}, &fake_ares); |
| |
| task_env_.FastForwardBy(base::Milliseconds(200)); |
| fake_ares.TriggerWriteReady(); |
| fake_ares.TriggerReadReady(); |
| fake_ares.InvokeCallbackOnNextProcessFD(ARES_ENODATA, {}); |
| EXPECT_CALL(*this, |
| CallbackWithDuration(Ge(base::Milliseconds(200)), |
| Result(base::unexpected(Error::kNoData)))); |
| task_env_.RunUntilIdle(); |
| } |
| |
| TEST_F(DNSClientTest, TimeoutWithDuration) { |
| FakeAres fake_ares; |
| |
| DNSClient::Options opts = { |
| .timeout = base::Seconds(1), |
| }; |
| auto client = DNSClientFactory().Resolve( |
| IPFamily::kIPv4, "test-url", GetCallbackWithDuration(), opts, &fake_ares); |
| |
| fake_ares.TriggerWriteReady(); |
| task_env_.RunUntilIdle(); |
| |
| EXPECT_CALL(*this, |
| CallbackWithDuration(Ge(base::Seconds(1)), |
| Result(base::unexpected(Error::kTimedOut)))); |
| task_env_.FastForwardBy(base::Seconds(2)); |
| } |
| |
| } // namespace |
| } // namespace net_base |