blob: ded8465cbde9d96e575bd79c697fee3a22169afd [file] [log] [blame] [edit]
// Copyright 2023 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net-base/socket.h"

#include <fcntl.h>
#include <linux/netlink.h>
#include <sys/socket.h>
#include <unistd.h>

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <memory>
#include <optional>
#include <string_view>
#include <utility>
#include <vector>

#include <base/files/scoped_file.h>
#include <base/functional/bind.h>
#include <base/test/task_environment.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "net-base/byte_utils.h"
#include "net-base/mock_socket.h"
namespace net_base {
namespace {
// Netlink protocol and multicast group mask (0, i.e. no groups) passed to
// SocketFactory::CreateNetlink() by the SocketFactory tests below.
constexpr int netlink_family{NETLINK_GENERIC};
constexpr uint32_t netlink_groups_mask{0};

// gMock helpers used throughout this file.
using testing::_;
using testing::Return;
// CreateFromFd() wraps a valid fd, and Get() exposes that same raw fd.
TEST(Socket, CreateFromFd) {
  int sv[2];
  ASSERT_EQ(socketpair(AF_UNIX, SOCK_RAW, 0, sv), 0);
  // Only one end of the pair is needed; close the peer immediately.
  close(sv[1]);
  base::ScopedFD fd(sv[0]);
  const int raw_fd = fd.get();

  auto socket = Socket::CreateFromFd(std::move(fd));
  // Fail the test cleanly instead of dereferencing a null socket below.
  ASSERT_NE(socket, nullptr);
  EXPECT_EQ(socket->Get(), raw_fd);
}
// CreateFromFd() rejects a default-constructed (invalid) ScopedFD by returning
// nullptr.
TEST(Socket, CreateFromFdInvalid) {
  EXPECT_EQ(Socket::CreateFromFd(base::ScopedFD()), nullptr);
}
// Socket::Release() gives the raw fd back to the caller without closing it.
TEST(Socket, Release) {
  int sv[2];
  ASSERT_EQ(socketpair(AF_UNIX, SOCK_RAW, 0, sv), 0);
  close(sv[1]);
  base::ScopedFD fd(sv[0]);
  const int raw_fd = fd.get();

  auto socket = Socket::CreateFromFd(std::move(fd));
  ASSERT_NE(socket, nullptr);
  // Socket::Release() returns the raw fd and does not close it.
  EXPECT_EQ(Socket::Release(std::move(socket)), raw_fd);
  // close() succeeding shows the fd was still open after Release().
  EXPECT_EQ(close(raw_fd), 0);
}
// Target for Socket readable callbacks, letting tests set expectations on how
// many times the callback runs.
struct MockCallback {
  MOCK_METHOD(void, OnSocketReadable, (), ());
};
// The readable callback runs once data arrives on the socket, and no longer
// runs after UnsetReadableCallback().
TEST(Socket, SetReadableCallback) {
  constexpr std::string_view msg = "hello, world";
  base::test::TaskEnvironment task_environment{
      base::test::TaskEnvironment::MainThreadType::IO};

  int sv[2];
  ASSERT_EQ(socketpair(AF_UNIX, SOCK_RAW, 0, sv), 0);
  std::unique_ptr<Socket> write_socket =
      Socket::CreateFromFd(base::ScopedFD(sv[0]));
  std::unique_ptr<Socket> read_socket =
      Socket::CreateFromFd(base::ScopedFD(sv[1]));
  ASSERT_NE(write_socket, nullptr);
  ASSERT_NE(read_socket, nullptr);

  MockCallback callback;
  // The callback should be called once after write_socket sends data.
  EXPECT_CALL(callback, OnSocketReadable()).WillOnce([&read_socket, msg]() {
    std::vector<uint8_t> buf;
    // Check the read itself so a failed RecvMessage() is reported directly,
    // not only as a payload mismatch.
    EXPECT_TRUE(read_socket->RecvMessage(&buf));
    EXPECT_EQ(byte_utils::ByteStringFromBytes(buf), msg);
  });
  read_socket->SetReadableCallback(base::BindRepeating(
      &MockCallback::OnSocketReadable, base::Unretained(&callback)));

  // A failed send would otherwise only show up as a missing callback.
  EXPECT_TRUE(write_socket->Send(msg));
  task_environment.RunUntilIdle();
  testing::Mock::VerifyAndClearExpectations(&callback);

  // After unsetting the callback, the callback should not be triggered.
  EXPECT_CALL(callback, OnSocketReadable()).Times(0);
  read_socket->UnsetReadableCallback();
  EXPECT_TRUE(write_socket->Send(msg));
  task_environment.RunUntilIdle();
}
// RecvMessage() on a SOCK_STREAM socket returns the complete payload written
// by a single Send().
TEST(Socket, ReadFromStream) {
  // Make sure this is long enough to exercise the chunking behavior, so
  // kMsgInts must be at least 1025. NOTE(review): presumably this keeps the
  // kMsgInts * 4 byte payload above RecvMessage()'s read chunk size — confirm.
  constexpr int kMsgInts = 1500;
  static_assert(kMsgInts >= 1025, "payload too short to exercise chunking");

  // Deterministic pseudo-random payload: fixed seed, and each rand_r() value
  // is split into four bytes, lowest byte first.
  std::vector<uint8_t> msg;
  msg.reserve(kMsgInts * 4);
  unsigned int seed = 0;
  for (int i = 0; i < kMsgInts; i++) {
    const int r = rand_r(&seed);
    msg.push_back(static_cast<uint8_t>(r & 0xFF));
    msg.push_back(static_cast<uint8_t>((r >> 8) & 0xFF));
    msg.push_back(static_cast<uint8_t>((r >> 16) & 0xFF));
    msg.push_back(static_cast<uint8_t>((r >> 24) & 0xFF));
  }

  base::test::TaskEnvironment task_environment{
      base::test::TaskEnvironment::MainThreadType::IO};
  int sv[2];
  ASSERT_EQ(socketpair(AF_UNIX, SOCK_STREAM, 0, sv), 0);
  std::unique_ptr<Socket> write_socket =
      Socket::CreateFromFd(base::ScopedFD(sv[0]));
  std::unique_ptr<Socket> read_socket =
      Socket::CreateFromFd(base::ScopedFD(sv[1]));
  ASSERT_NE(write_socket, nullptr);
  ASSERT_NE(read_socket, nullptr);

  std::vector<uint8_t> buf;
  read_socket->SetReadableCallback(base::BindRepeating(
      [](Socket* sock, std::vector<uint8_t>* out,
         base::RepeatingClosure quit_closure) {
        EXPECT_TRUE(sock->RecvMessage(out));
        quit_closure.Run();
      },
      read_socket.get(), &buf, task_environment.QuitClosure()));

  // If nothing was sent, RunUntilQuit() would wait for a readable callback
  // that never comes, so stop here instead.
  ASSERT_TRUE(write_socket->Send(msg));
  task_environment.RunUntilQuit();
  EXPECT_EQ(msg, buf);
}
// Socket with a mocked RecvFrom(), so RecvMessage() can be tested without a
// real peer. It wraps a read-only /dev/null fd and is typed as SOCK_RAW.
class SocketUnderTest : public Socket {
 public:
  // O_CLOEXEC keeps this fd from leaking into any child process spawned while
  // the test runs.
  SocketUnderTest()
      : Socket(base::ScopedFD(open("/dev/null", O_RDONLY | O_CLOEXEC)),
               SOCK_RAW) {}
  ~SocketUnderTest() override = default;

  // Mocks RecvFrom() to verify RecvMessage().
  MOCK_METHOD(std::optional<size_t>,
              RecvFrom,
              (base::span<uint8_t>, int, struct sockaddr*, socklen_t*),
              (const, override));
};
// RecvMessage() reports failure when the underlying RecvFrom() fails.
TEST(Socket, RecvMessageFailed) {
  SocketUnderTest socket;
  EXPECT_CALL(socket, RecvFrom).WillOnce(Return(std::nullopt));

  std::vector<uint8_t> message;
  EXPECT_FALSE(socket.RecvMessage(&message));
}
// RecvMessage() succeeds and returns the payload. The mocked RecvFrom()
// answers the MSG_TRUNC | MSG_PEEK probe with the message size, and answers
// the plain (flags == 0) read with the payload. That read fails unless its
// buffer is exactly the peeked size.
TEST(Socket, RecvMessageSuccess) {
  const std::vector<uint8_t> recv_data = {1, 3, 5, 7, 9};
  SocketUnderTest socket;
  EXPECT_CALL(socket, RecvFrom(_, MSG_TRUNC | MSG_PEEK, _, _))
      .WillOnce(Return(recv_data.size()));
  EXPECT_CALL(socket, RecvFrom(_, 0, _, _))
      .WillOnce([&recv_data](base::span<uint8_t> buf, int /*flags*/,
                             struct sockaddr* /*src_addr*/,
                             socklen_t* /*addrlen*/) -> std::optional<size_t> {
        if (buf.size() != recv_data.size()) {
          return std::nullopt;
        }
        std::copy(recv_data.begin(), recv_data.end(), buf.data());
        return buf.size();
      });

  std::vector<uint8_t> message;
  EXPECT_TRUE(socket.RecvMessage(&message));
  EXPECT_EQ(message, recv_data);
}
// Matches a non-null socket address that, viewed as a sockaddr_nl, has family
// AF_NETLINK and multicast group mask |groups|.
MATCHER_P(IsNetlinkAddr, groups, "") {
  // Reject a null address instead of dereferencing it and crashing the test.
  if (arg == nullptr) {
    return false;
  }
  const struct sockaddr_nl* socket_address =
      reinterpret_cast<const struct sockaddr_nl*>(arg);
  return socket_address->nl_family == AF_NETLINK &&
         socket_address->nl_groups == groups;
}
// SocketFactory with only Create() mocked. The CreateNetlink() tests use it to
// inject a MockSocket and check how CreateNetlink() configures that socket.
class MockSocketFactory : public SocketFactory {
 public:
  MOCK_METHOD(std::unique_ptr<Socket>,
              Create,
              (int domain, int type, int protocol),
              (override));
};
// CreateNetlink() requests a PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC socket.
// When called without a buffer size, it sets kNetlinkReceiveBufferSize. It
// then binds to the requested groups.
TEST(SocketFactory, CreateNetlinkSuccess) {
  MockSocketFactory socket_factory;

  auto make_configured_socket = []() {
    auto mock_socket = std::make_unique<MockSocket>();
    EXPECT_CALL(*mock_socket,
                SetReceiveBuffer(SocketFactory::kNetlinkReceiveBufferSize))
        .WillOnce(Return(true));
    EXPECT_CALL(*mock_socket, Bind(IsNetlinkAddr(netlink_groups_mask),
                                   sizeof(struct sockaddr_nl)))
        .WillOnce(Return(true));
    return mock_socket;
  };
  EXPECT_CALL(socket_factory,
              Create(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, netlink_family))
      .WillOnce(make_configured_socket);

  EXPECT_NE(socket_factory.CreateNetlink(netlink_family, netlink_groups_mask),
            nullptr);
}
// Passing std::nullopt as the receive buffer size makes CreateNetlink() skip
// SetReceiveBuffer() while still binding the socket.
TEST(SocketFactory, CreateNetlinkNotSetReceiveBuffer) {
  MockSocketFactory socket_factory;

  auto make_socket_without_buffer = []() {
    auto mock_socket = std::make_unique<MockSocket>();
    EXPECT_CALL(*mock_socket,
                SetReceiveBuffer(SocketFactory::kNetlinkReceiveBufferSize))
        .Times(0);
    EXPECT_CALL(*mock_socket, Bind(IsNetlinkAddr(netlink_groups_mask),
                                   sizeof(struct sockaddr_nl)))
        .WillOnce(Return(true));
    return mock_socket;
  };
  EXPECT_CALL(socket_factory, Create).WillOnce(make_socket_without_buffer);

  const std::optional<int> no_buffer_size = std::nullopt;
  EXPECT_NE(socket_factory.CreateNetlink(netlink_family, netlink_groups_mask,
                                         no_buffer_size),
            nullptr);
}
// CreateNetlink() returns nullptr when the underlying Create() fails.
TEST(SocketFactory, CreateNetlinkSocketFail) {
  MockSocketFactory socket_factory;
  EXPECT_CALL(socket_factory, Create).WillOnce(Return(nullptr));

  auto netlink_socket =
      socket_factory.CreateNetlink(netlink_family, netlink_groups_mask);
  EXPECT_EQ(netlink_socket, nullptr);
}
// CreateNetlink() returns nullptr when Bind() fails, even though setting the
// receive buffer succeeded.
TEST(SocketFactory, CreateNetlinkBindFail) {
  MockSocketFactory socket_factory;

  auto make_socket_with_failing_bind = []() {
    auto mock_socket = std::make_unique<MockSocket>();
    EXPECT_CALL(*mock_socket,
                SetReceiveBuffer(SocketFactory::kNetlinkReceiveBufferSize))
        .WillOnce(Return(true));
    EXPECT_CALL(*mock_socket, Bind).WillOnce(Return(false));
    return mock_socket;
  };
  EXPECT_CALL(socket_factory, Create).WillOnce(make_socket_with_failing_bind);

  EXPECT_EQ(socket_factory.CreateNetlink(netlink_family, netlink_groups_mask),
            nullptr);
}
} // namespace
} // namespace net_base