// blob: bc70af335350aa811ac4148fb53ef4e6111b6b34 [file] [log] [blame] [edit]
// Copyright 2018 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net-base/netlink_socket.h"
#include <linux/netlink.h>
#include <algorithm>
#include <memory>
#include <optional>
#include <utility>
#include <vector>
#include <base/containers/span.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "net-base/byte_utils.h"
#include "net-base/mock_socket.h"
#include "net-base/netlink_message.h"
using testing::_;
using testing::Invoke;
using testing::Return;
using testing::Test;
namespace net_base {
class NetlinkSocketTest : public Test {
public:
void SetUp() override {
auto mock_socket_factory = std::make_unique<MockSocketFactory>();
EXPECT_CALL(*mock_socket_factory, CreateNetlink(NETLINK_GENERIC, 0, _))
.WillOnce([&]() {
auto socket = std::make_unique<MockSocket>();
mock_socket_ = socket.get();
return socket;
});
netlink_socket_ =
NetlinkSocket::CreateWithSocketFactory(std::move(mock_socket_factory));
EXPECT_NE(netlink_socket_, nullptr);
}
protected:
std::unique_ptr<NetlinkSocket> netlink_socket_;
MockSocket* mock_socket_;
};
// Test helper that serves a canned byte payload through a receive-style call.
class FakeSocketRead {
 public:
  // Stores a private copy of |next_read_string| to be served by the next
  // FakeSuccessfulRead() call.
  explicit FakeSocketRead(base::span<const uint8_t> next_read_string)
      : next_read_string_(next_read_string.begin(), next_read_string.end()) {}

  // Copies as many bytes of |next_read_string_| as fit into |buf|, then clears
  // |next_read_string_|; bytes that did not fit are discarded.
  //
  // Returns the number of bytes copied, or std::nullopt when |buf| is empty
  // (in which case |next_read_string_| is left untouched).
  // |flags|, |src_addr| and |addrlen| are ignored; they presumably exist only
  // so the signature matches the socket receive call being faked.
  std::optional<size_t> FakeSuccessfulRead(base::span<uint8_t> buf,
                                           int flags,
                                           struct sockaddr* src_addr,
                                           socklen_t* addrlen) {
    if (buf.empty()) {
      return std::nullopt;
    }
    const size_t read_bytes = std::min(buf.size(), next_read_string_.size());
    // std::copy_n is well defined for a zero-length copy from an empty vector,
    // unlike memcpy() with a null source pointer.
    std::copy_n(next_read_string_.begin(), read_bytes, buf.data());
    next_read_string_.clear();
    return read_bytes;
  }

 private:
  // Payload handed out by the next successful read; emptied once served.
  std::vector<uint8_t> next_read_string_;
};
// Matches a range that has the same number of elements as |value| and whose
// elements compare equal to those of |value|, position by position.
// Element-wise comparison keeps the matcher correct for element types wider
// than one byte, where a size()-byte memcmp() would inspect only a prefix.
MATCHER_P(SpanEq, value, "") {
  return std::equal(arg.begin(), arg.end(), value.begin(), value.end());
}
// When the socket factory yields no socket, CreateWithSocketFactory() must
// report failure by returning null.
TEST_F(NetlinkSocketTest, InitBrokenSocketTest) {
  auto failing_factory = std::make_unique<MockSocketFactory>();
  // The factory is asked exactly once and answers with a null socket.
  EXPECT_CALL(*failing_factory, CreateNetlink(NETLINK_GENERIC, 0, _))
      .WillOnce(Return(nullptr));

  EXPECT_EQ(NetlinkSocket::CreateWithSocketFactory(std::move(failing_factory)),
            nullptr);
}
// Drives NetlinkSocket::SendMessage() through three outcomes of the mocked
// socket's Send(): complete write, partial write, and failure.
TEST_F(NetlinkSocketTest, SendMessageTest) {
  const std::vector<uint8_t> payload =
      byte_utils::ByteStringToBytes("This text is really arbitrary");
  const size_t full_length = payload.size();
  const size_t truncated_length = full_length - 3;

  // Socket reports the whole payload as written: SendMessage() succeeds.
  EXPECT_CALL(*mock_socket_, Send(SpanEq(payload), 0))
      .WillOnce(Return(full_length));
  EXPECT_TRUE(netlink_socket_->SendMessage(payload));

  // Socket reports fewer bytes than requested: SendMessage() fails.
  EXPECT_CALL(*mock_socket_, Send(SpanEq(payload), 0))
      .WillOnce(Return(truncated_length));
  EXPECT_FALSE(netlink_socket_->SendMessage(payload));

  // Socket reports no result at all: SendMessage() fails.
  EXPECT_CALL(*mock_socket_, Send(SpanEq(payload), 0))
      .WillOnce(Return(std::nullopt));
  EXPECT_FALSE(netlink_socket_->SendMessage(payload));
}
// Checks the values GetSequenceNumber() hands out after the counter has been
// forced to a known state.
TEST_F(NetlinkSocketTest, SequenceNumberTest) {
  // From an ordinary seed, the next number handed out is seed + 1.
  constexpr uint32_t kSeed = 42;
  netlink_socket_->set_sequence_number_for_test(kSeed);
  const auto next_after_seed = netlink_socket_->GetSequenceNumber();
  EXPECT_EQ(kSeed + 1, next_after_seed);

  // Seeded with |NetlinkMessage::kBroadcastSequenceNumber|, the socket must
  // not hand that reserved value back.
  netlink_socket_->set_sequence_number_for_test(
      NetlinkMessage::kBroadcastSequenceNumber);
  const auto next_after_broadcast = netlink_socket_->GetSequenceNumber();
  EXPECT_NE(NetlinkMessage::kBroadcastSequenceNumber, next_after_broadcast);
}
} // namespace net_base