blob: 50cecc77b138eb33833f263d593b229e67778848 [file] [log] [blame]
// Copyright 2023 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net-base/socket.h"
#include <fcntl.h>
#include <linux/netlink.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <memory>
#include <optional>
#include <utility>
#include <base/check_op.h>
#include <base/functional/bind.h>
#include <base/functional/callback.h>
#include <base/logging.h>
#include <base/posix/eintr_wrapper.h>
namespace net_base {
namespace {
// Converts a syscall-style signed byte count into an optional: any negative
// value (e.g. the -1 error sentinel) becomes std::nullopt, and a
// non-negative value is returned as an unsigned size_t.
std::optional<size_t> ToOptionalSizeT(ssize_t size) {
  if (size >= 0) {
    return static_cast<size_t>(size);
  }
  return std::nullopt;
}
} // namespace
// static
std::unique_ptr<Socket> Socket::CreateFromFd(base::ScopedFD fd) {
  // Only valid descriptors are wrapped; the constructor treats an invalid fd
  // as fatal, so an invalid one is reported to the caller as nullptr instead.
  std::unique_ptr<Socket> wrapped;
  if (fd.is_valid()) {
    wrapped.reset(new Socket(std::move(fd)));
  }
  return wrapped;
}
Socket::Socket(base::ScopedFD fd) : fd_(std::move(fd)) {
  // CreateFromFd() only constructs a Socket from a valid descriptor, so
  // reaching here with an invalid one indicates a caller bug.
  if (!fd_.is_valid()) {
    LOG(FATAL) << "the socket fd is invalid";
  }
}
// Defaulted: fd_ (moved in from a base::ScopedFD) owns the descriptor and
// releases it in its own destructor, unless Release() already took it.
Socket::~Socket() = default;
int Socket::Get() const {
  // Non-owning access to the raw descriptor; this Socket keeps ownership.
  const int raw_fd = fd_.get();
  return raw_fd;
}
// static
int Socket::Release(std::unique_ptr<Socket> socket) {
  // Detach the descriptor from fd_ before |socket| is destroyed, so the
  // caller takes over responsibility for it. A null socket yields -1.
  return socket ? socket->fd_.release() : -1;
}
// Some system calls can be interrupted and return EINTR, but will succeed on
// retry. The HANDLE_EINTR macro retries a call if it returns EINTR. For a
// list of system calls that can return EINTR, see 'man 7 signal' under the
// heading "Interruption of System Calls and Library Functions by Signal
// Handlers".
std::unique_ptr<Socket> Socket::Accept(struct sockaddr* addr,
                                       socklen_t* addrlen) const {
  // Retry accept() if a signal interrupts it. On failure the -1 result is
  // turned into nullptr by CreateFromFd().
  const int accepted_fd = HANDLE_EINTR(accept(fd_.get(), addr, addrlen));
  return CreateFromFd(base::ScopedFD(accepted_fd));
}
bool Socket::Bind(const struct sockaddr* addr, socklen_t addrlen) const {
  // bind(2) returns 0 on success and -1 (with errno set) on failure.
  const int rv = bind(fd_.get(), addr, addrlen);
  return rv == 0;
}
bool Socket::GetSockName(struct sockaddr* addr, socklen_t* addrlen) const {
  // On success getsockname(2) fills |addr| and updates |*addrlen|; it
  // returns -1 on failure.
  const int status = getsockname(fd_.get(), addr, addrlen);
  return status == 0;
}
bool Socket::Listen(int backlog) const {
  // Mark the socket as passive so that Accept() can take connections.
  const bool listening = listen(fd_.get(), backlog) == 0;
  return listening;
}
// NOLINTNEXTLINE(runtime/int)
std::optional<int> Socket::Ioctl(unsigned long request, void* argp) const {
  // A negative result means failure; any non-negative value is the
  // request-specific return value of ioctl(2).
  const int rv = HANDLE_EINTR(ioctl(fd_.get(), request, argp));
  return rv >= 0 ? std::optional<int>(rv) : std::nullopt;
}
std::optional<size_t> Socket::RecvFrom(base::span<uint8_t> buf,
                                       int flags,
                                       struct sockaddr* src_addr,
                                       socklen_t* addrlen) const {
  // Returns the byte count reported by recvfrom(2), or nullopt on error.
  // With MSG_TRUNC in |flags| the count can exceed buf.size().
  uint8_t* const dst = buf.data();
  const size_t capacity = buf.size();
  return ToOptionalSizeT(HANDLE_EINTR(
      recvfrom(fd_.get(), dst, capacity, flags, src_addr, addrlen)));
}
bool Socket::RecvMessage(std::vector<uint8_t>* message) const {
  DCHECK(message) << "message is null";
  // Peek at the next message with MSG_TRUNC so the returned length is the
  // full size of the pending data even though only a one-byte buffer is
  // supplied; MSG_PEEK leaves the data queued for the real read below.
  std::vector<uint8_t> probe(1);
  const std::optional<size_t> pending =
      RecvFrom(probe, MSG_TRUNC | MSG_PEEK, nullptr, nullptr);
  if (!pending) {
    return false;
  }
  const size_t expected = pending.value();
  // Size the output to exactly the pending length, then consume the data.
  message->resize(expected, 0);
  const std::optional<size_t> received =
      RecvFrom(*message, 0, nullptr, nullptr);
  return received.has_value() && *received == expected;
}
std::optional<size_t> Socket::Send(base::span<const uint8_t> buf,
                                   int flags) const {
  // Returns the number of bytes send(2) accepted, or nullopt on error. A
  // partial send is reported as-is; no retry loop happens here.
  return ToOptionalSizeT(
      HANDLE_EINTR(send(fd_.get(), buf.data(), buf.size(), flags)));
}
std::optional<size_t> Socket::SendTo(base::span<const uint8_t> buf,
                                     int flags,
                                     const struct sockaddr* dest_addr,
                                     socklen_t addrlen) const {
  // Like Send(), but with an explicit destination address for sendto(2).
  const uint8_t* const src = buf.data();
  const size_t length = buf.size();
  const ssize_t sent = HANDLE_EINTR(
      sendto(fd_.get(), src, length, flags, dest_addr, addrlen));
  return ToOptionalSizeT(sent);
}
bool Socket::SetNonBlocking() const {
  // Read the current status flags so O_NONBLOCK is added to them rather than
  // replacing them. F_GETFL returns -1 on failure; OR-ing O_NONBLOCK into -1
  // would ask F_SETFL to turn on every modifiable flag (O_ASYNC, O_APPEND,
  // ...), so report failure instead of issuing that request.
  const int flags = fcntl(fd_.get(), F_GETFL);
  if (flags < 0) {
    return false;
  }
  // F_SETFL returns 0 on success and -1 on failure.
  return HANDLE_EINTR(fcntl(fd_.get(), F_SETFL, flags | O_NONBLOCK)) == 0;
}
bool Socket::SetReceiveBuffer(int size) const {
  // SO_RCVBUFFORCE behaves like SO_RCVBUF but may exceed the rmem_max limit;
  // it requires CAP_NET_ADMIN. Note: the kernel sets the buffer to 2*size to
  // allow for struct sk_buff bookkeeping overhead.
  const int rv = setsockopt(fd_.get(), SOL_SOCKET, SO_RCVBUFFORCE, &size,
                            sizeof(size));
  return rv == 0;
}
std::unique_ptr<Socket> SocketFactory::Create(int domain,
                                              int type,
                                              int protocol) {
  // socket(2) returns -1 on failure, which CreateFromFd() maps to nullptr.
  const int fd = socket(domain, type, protocol);
  return Socket::CreateFromFd(base::ScopedFD(fd));
}
std::unique_ptr<Socket> SocketFactory::CreateNetlink(
    int netlink_family,
    uint32_t netlink_groups_mask,
    std::optional<int> receive_buffer_size) {
  // Opens a close-on-exec netlink datagram socket for |netlink_family|,
  // optionally resizes its receive buffer, and binds it subscribed to the
  // multicast groups in |netlink_groups_mask|. Returns nullptr if opening or
  // binding fails.
  auto socket = Create(PF_NETLINK, SOCK_DGRAM | SOCK_CLOEXEC, netlink_family);
  if (!socket) {
    PLOG(ERROR) << "Failed to open netlink socket for family "
                << netlink_family;
    return nullptr;
  }

  // A receive-buffer failure is non-fatal: the socket is still bound and
  // returned. Log the size actually requested by the caller, which need not
  // be kNetlinkReceiveBufferSize.
  if (receive_buffer_size.has_value() &&
      !socket->SetReceiveBuffer(*receive_buffer_size)) {
    PLOG(WARNING) << "Failed to increase receive buffer size to "
                  << *receive_buffer_size << "b";
  }

  // Value-initialize so nl_pad and nl_pid are zero.
  struct sockaddr_nl addr = {};
  addr.nl_family = AF_NETLINK;
  addr.nl_groups = netlink_groups_mask;
  if (!socket->Bind(reinterpret_cast<struct sockaddr*>(&addr), sizeof(addr))) {
    PLOG(ERROR) << "Netlink socket bind failed for family " << netlink_family;
    return nullptr;
  }
  return socket;
}
} // namespace net_base