blob: 547c2ae64bf0e2b68c5ce67079a28c9daa92143d [file] [log] [blame]
// Copyright 2020 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "patchpanel/socket_forwarder.h"
#include <netinet/in.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <memory>
#include <utility>
#include <vector>
#include <base/callback.h>
#include <base/run_loop.h>
#include <base/task/single_thread_task_executor.h>
#include <brillo/message_loops/base_message_loop.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
using testing::Each;
namespace patchpanel {
namespace {
// SocketForwarder reads blocks of 4096 bytes. The payload is deliberately
// larger than one such block so a single message needs more than one read to
// be forwarded.
constexpr int kDataSize = 5000;
// Does a blocking read on |socket| until it has received exactly
// |expected_byte_count| bytes, which are written into |buf|.
//
// |buf| must have room for at least |expected_byte_count| bytes; this function
// never writes past that many bytes.
//
// Returns true once all bytes have arrived. Returns false as soon as
// RecvFrom() returns 0 or a negative value (presumably EOF or an error) before
// the full count has been read.
bool Read(Socket* socket, char* buf, int expected_byte_count) {
  int read_byte_count = 0;
  while (read_byte_count < expected_byte_count) {
    // Only request the bytes that are still missing so a single read can never
    // overrun the end of |buf|.
    int bytes = socket->RecvFrom(buf + read_byte_count,
                                 expected_byte_count - read_byte_count);
    if (bytes <= 0)
      return false;
    read_byte_count += bytes;
  }
  // The loop can only exit here with read_byte_count == expected_byte_count,
  // because each read is capped at the remaining byte count.
  return true;
}
} // namespace
class SocketForwarderTest : public ::testing::Test {
void SetUp() override {
int fds0[2], fds1[2];
ASSERT_NE(-1, socketpair(AF_UNIX, SOCK_STREAM, 0 /* protocol */, fds0));
ASSERT_NE(-1, socketpair(AF_UNIX, SOCK_STREAM, 0 /* protocol */, fds1));
peer0_ = std::make_unique<Socket>(base::ScopedFD(fds0[0]));
peer1_ = std::make_unique<Socket>(base::ScopedFD(fds1[0]));
forwarder_ = std::make_unique<SocketForwarder>(
"test", std::make_unique<Socket>(base::ScopedFD(fds0[1])),
std::make_unique<Socket>(base::ScopedFD(fds1[1])));
}
protected:
std::unique_ptr<Socket> peer0_;
std::unique_ptr<Socket> peer1_;
// Forwards data betweeok |peer0_| and |peer1_|.
std::unique_ptr<SocketForwarder> forwarder_;
base::SingleThreadTaskExecutor task_executor_{base::MessagePumpType::IO};
brillo::BaseMessageLoop brillo_loop_{task_executor_.task_runner()};
};
// Each peer sends kDataSize bytes and then half-closes its socket. The
// forwarder must stop on its own, and each peer must have received the full
// payload sent by the other one.
TEST_F(SocketForwarderTest, ForwardDataAndClose) {
  base::RunLoop loop;
  forwarder_->SetStopQuitClosureForTesting(loop.QuitClosure());
  forwarder_->Start();

  std::vector<char> msg(kDataSize, 1);
  // Fatal checks: after a short send, Read() below would block waiting for
  // bytes that were never sent.
  ASSERT_EQ(peer0_->SendTo(msg.data(), msg.size()), kDataSize);
  ASSERT_EQ(peer1_->SendTo(msg.data(), msg.size()), kDataSize);

  // Close both sockets for writing. Fatal, because without the half-close the
  // forwarder is presumably never told to stop and loop.Run() would not return.
  ASSERT_NE(shutdown(peer0_->fd(), SHUT_WR), -1);
  ASSERT_NE(shutdown(peer1_->fd(), SHUT_WR), -1);

  loop.Run();
  EXPECT_FALSE(forwarder_->IsRunning());

  // Verify that all the data has been forwarded to the peers.
  std::vector<char> received_by_peer0(kDataSize);
  std::vector<char> received_by_peer1(kDataSize);
  // Fatal, so a failed read does not also report a misleading content
  // mismatch on the zero-filled buffer.
  ASSERT_TRUE(Read(peer1_.get(), received_by_peer1.data(), kDataSize));
  ASSERT_TRUE(Read(peer0_.get(), received_by_peer0.data(), kDataSize));
  EXPECT_THAT(received_by_peer0, Each(1));
  EXPECT_THAT(received_by_peer1, Each(1));
}
// Fully closing one peer, rather than half-closing it, must make the
// forwarder stop by itself. Per the test name, this is expected to surface as
// EPOLLHUP on the forwarder's end of the pair.
TEST_F(SocketForwarderTest, PeerSignalEPOLLHUP) {
  base::RunLoop run_loop;
  forwarder_->SetStopQuitClosureForTesting(run_loop.QuitClosure());
  forwarder_->Start();

  // Destroy the destination peer. The Socket owns the fd it was built from
  // in SetUp(), so this presumably closes it.
  peer1_ = nullptr;

  run_loop.Run();
  EXPECT_FALSE(forwarder_->IsRunning());
}
} // namespace patchpanel