blob: 6fd3923f70e2b1b23e6a41620dac9e1a0f900cec [file] [log] [blame] [edit]
// Copyright 2020 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net-base/socket_forwarder.h"
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <memory>
#include <optional>
#include <utility>
#include <vector>
#include <base/functional/callback.h>
#include <base/task/single_thread_task_executor.h>
#include <base/test/test_future.h>
#include <brillo/message_loops/base_message_loop.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "net-base/socket.h"
using testing::Each;
namespace net_base {
namespace {
// Number of bytes each peer sends through the forwarder. SocketForwarder
// reads in blocks of 4096 bytes, so this payload is larger than one block.
constexpr size_t kDataSize{5000};
// Does a blocking read on |socket| until it fills up the |buf|.
// Does a blocking read on |socket| until |buf| is completely filled.
// Returns true once every byte of |buf| has been received. Returns false if
// RecvFrom() fails, or if it returns zero bytes before |buf| is full.
bool Read(const net_base::Socket& socket, base::span<char> buf) {
  while (!buf.empty()) {
    const std::optional<size_t> bytes = socket.RecvFrom(buf);
    // On a stream socket, a zero-byte read for a non-empty buffer means the
    // peer shut down its write side (assuming RecvFrom() reports the
    // recvfrom(2) result). No more data will arrive, and subspan(0) would not
    // advance, so keep looping would spin forever.
    if (!bytes.has_value() || *bytes == 0) {
      return false;
    }
    buf = buf.subspan(*bytes);
  }
  return true;
}
} // namespace
class SocketForwarderTest : public ::testing::Test {
void SetUp() override {
int fds0[2], fds1[2];
ASSERT_NE(-1, socketpair(AF_UNIX, SOCK_STREAM, 0 /* protocol */, fds0));
ASSERT_NE(-1, socketpair(AF_UNIX, SOCK_STREAM, 0 /* protocol */, fds1));
peer0_ = net_base::Socket::CreateFromFd(base::ScopedFD(fds0[0]));
peer1_ = net_base::Socket::CreateFromFd(base::ScopedFD(fds1[0]));
forwarder_ = std::make_unique<SocketForwarder>(
"test", net_base::Socket::CreateFromFd(base::ScopedFD(fds0[1])),
net_base::Socket::CreateFromFd((base::ScopedFD(fds1[1]))));
}
protected:
std::unique_ptr<net_base::Socket> peer0_;
std::unique_ptr<net_base::Socket> peer1_;
// Forwards data between |peer0_| and |peer1_|.
std::unique_ptr<SocketForwarder> forwarder_;
base::SingleThreadTaskExecutor task_executor_{base::MessagePumpType::IO};
brillo::BaseMessageLoop brillo_loop_{task_executor_.task_runner()};
};
// Sends kDataSize bytes in each direction and then half-closes both peers.
// Verifies that the forwarder stops, and that each peer received the other
// peer's full payload.
TEST_F(SocketForwarderTest, ForwardDataAndClose) {
  base::test::TestFuture<void> signal;
  forwarder_->SetStopQuitClosureForTesting(signal.GetCallback());
  forwarder_->Start();
  std::vector<char> msg(kDataSize, 1);
  EXPECT_EQ(peer0_->Send(msg), kDataSize);
  EXPECT_EQ(peer1_->Send(msg), kDataSize);
  // Close both sockets for writing. The forwarder's ends then read
  // end-of-stream on both sides.
  EXPECT_NE(shutdown(peer0_->Get(), SHUT_WR), -1);
  EXPECT_NE(shutdown(peer1_->Get(), SHUT_WR), -1);
  // Block until the stop closure registered above has run.
  EXPECT_TRUE(signal.Wait());
  EXPECT_FALSE(forwarder_->IsRunning());
  // Each peer should have received exactly what the other one sent. The reads
  // use ASSERT so the content checks never run on a partially filled
  // (zero-initialized) buffer.
  std::vector<char> received_by_peer0(kDataSize);
  std::vector<char> received_by_peer1(kDataSize);
  ASSERT_TRUE(Read(*peer1_, received_by_peer1));
  ASSERT_TRUE(Read(*peer0_, received_by_peer0));
  EXPECT_THAT(received_by_peer0, Each(1));
  EXPECT_THAT(received_by_peer1, Each(1));
}
// Destroys one peer while the forwarder is running and checks that the
// forwarder stops by itself, without being told to.
TEST_F(SocketForwarderTest, PeerSignalEPOLLHUP) {
  base::test::TestFuture<void> stopped;
  forwarder_->SetStopQuitClosureForTesting(stopped.GetCallback());
  forwarder_->Start();
  // Drop the destination peer. Destroying the Socket presumably closes its
  // end of the socket pair.
  peer1_ = nullptr;
  EXPECT_TRUE(stopped.Wait());
  EXPECT_FALSE(forwarder_->IsRunning());
}
} // namespace net_base