// blob: 1c84a42df52e767684747ec6f61120e81bf33be4
// Copyright 2023 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef NET_BASE_SOCKET_H_
#define NET_BASE_SOCKET_H_
#include <sys/socket.h>

#include <iosfwd>
#include <memory>
#include <optional>
#include <vector>

#include <base/containers/span.h>
#include <base/files/file_descriptor_watcher_posix.h>
#include <base/files/scoped_file.h>
#include <base/functional/callback.h>
#include <brillo/brillo_export.h>
namespace net_base {
// Represents a socket file descriptor, and provides the encapsulation for
// the standard POSIX and Linux socket operations.
//
// Instances are obtained through the static factory functions below (the
// constructor is protected) and are handled through std::unique_ptr. The
// socket operations are virtual so that a mock Socket can be substituted in
// tests (see SocketFactory).
//
// NOTE(review): several virtual methods declare default arguments. Default
// arguments are resolved from the static type at the call site, so overrides
// should repeat exactly the same defaults.
class BRILLO_EXPORT Socket {
public:
// Creates the socket instance by socket(...) method. On failure, returns
// nullptr and the errno is set. The caller should use PLOG() to print errno.
// The arguments are presumably passed straight through to socket(2) --
// confirm against the implementation. |protocol| defaults to 0.
static std::unique_ptr<Socket> Create(int domain, int type, int protocol = 0);
// Creates the socket instance with the socket descriptor. Returns nullptr if
// |fd| is invalid. |fd| is taken by value, so ownership of the descriptor is
// handed over to this call.
static std::unique_ptr<Socket> CreateFromFd(base::ScopedFD fd);
virtual ~Socket();
// Not copyable. Since the copy operations and the destructor are
// user-declared, no move operations are implicitly generated either.
Socket(const Socket&) = delete;
Socket& operator=(const Socket&) = delete;
// Returns the raw socket file descriptor. The descriptor is still held by
// this instance; use Release() to take it over.
int Get() const;
// Sets/unsets the callback which will be called when the socket is ready to
// be read.
// NOTE(review): presumably implemented on top of |watcher_|; any requirement
// on the calling thread/sequence is not visible here -- confirm against the
// implementation.
void SetReadableCallback(base::RepeatingClosure callback);
void UnsetReadableCallback();
// Releases and returns the socket file descriptor, allowing the socket to
// remain open as the Socket is destroyed. After the call, |socket| will be
// dropped, so the user cannot hold a Socket instance with invalid file
// descriptor. The return value is [[nodiscard]]: the caller becomes
// responsible for the returned descriptor.
[[nodiscard]] static int Release(std::unique_ptr<Socket> socket);
// Delegates to accept(fd_.get(), ...). Returns the new connected socket.
// NOTE(review): the failure result is not documented here; presumably
// nullptr with errno set, as for Create() -- confirm.
virtual std::unique_ptr<Socket> Accept(struct sockaddr* addr,
socklen_t* addrlen) const;
// Delegates to bind(fd_.get(), ...). Returns true if successful.
// On failure, the errno is set. The caller should use PLOG() to print errno.
virtual bool Bind(const struct sockaddr* addr, socklen_t addrlen) const;
// Delegates to connect(fd_.get(), ...). Returns true if successful.
// On failure, the errno is set. The caller should use PLOG() to print errno.
virtual bool Connect(const struct sockaddr* addr, socklen_t addrlen) const;
// Delegates to getsockname(fd_.get(), ...). Returns true if successful.
// On failure, the errno is set. The caller should use PLOG() to print errno.
virtual bool GetSockName(struct sockaddr* addr, socklen_t* addrlen) const;
// Delegates to listen(fd_.get(), ...). Returns true if successful.
// On failure, the errno is set. The caller should use PLOG() to print errno.
virtual bool Listen(int backlog) const;
// Delegates to ioctl(fd_.get(), ...). Returns the returned value of ioctl()
// if successful. On failure, returns std::nullopt and the errno is set. The
// caller should use PLOG() to print errno.
// |request| is declared as unsigned long, presumably to match the ioctl()
// prototype, hence the lint suppression below.
// NOLINTNEXTLINE(runtime/int)
virtual std::optional<int> Ioctl(unsigned long request, void* argp) const;
// Delegates to recvfrom(fd_.get(), ...). On success, returns the length of
// the received message in bytes. On failure, returns std::nullopt and the
// errno is set. The caller should use PLOG() to print errno.
// |src_addr| and |addrlen| default to nullptr and |flags| defaults to 0.
virtual std::optional<size_t> RecvFrom(base::span<char> buf,
int flags = 0,
struct sockaddr* src_addr = nullptr,
socklen_t* addrlen = nullptr) const;
// Overload of the above taking a buffer of uint8_t instead of char.
virtual std::optional<size_t> RecvFrom(base::span<uint8_t> buf,
int flags = 0,
struct sockaddr* src_addr = nullptr,
socklen_t* addrlen = nullptr) const;
// Reads data from the socket into |message| and returns true if successful.
// The |message| parameter will be resized to hold the entirety of the read
// message (and any data in |message| will be overwritten). If the socket
// is stream-oriented, this will read all available data.
virtual bool RecvMessage(std::vector<uint8_t>* message) const;
// Delegates to send(fd_.get(), ...). On success, returns the number of bytes
// sent. On failure, returns std::nullopt and the errno is set. The caller
// should use PLOG() to print errno.
// |flags| defaults to MSG_NOSIGNAL.
virtual std::optional<size_t> Send(base::span<const char> buf,
int flags = MSG_NOSIGNAL) const;
// Overload of the above taking a buffer of uint8_t instead of char.
virtual std::optional<size_t> Send(base::span<const uint8_t> buf,
int flags = MSG_NOSIGNAL) const;
// Delegates to sendto(fd_.get(), ...). On success, returns the number of
// bytes sent. On failure, returns std::nullopt and the errno is set. The
// caller should use PLOG() to print errno.
// Unlike Send(), |flags| has no default here.
virtual std::optional<size_t> SendTo(base::span<const char> buf,
int flags,
const struct sockaddr* dest_addr,
socklen_t addrlen) const;
// Overload of the above taking a buffer of uint8_t instead of char.
virtual std::optional<size_t> SendTo(base::span<const uint8_t> buf,
int flags,
const struct sockaddr* dest_addr,
socklen_t addrlen) const;
// Delegates to setsockopt(fd_.get(), ...). Returns true if successful.
// On failure, the errno is set. The caller should use PLOG() to print errno.
// |opt_bytes| holds the raw bytes of the option value.
virtual bool SetSockOpt(int level,
int optname,
base::span<const uint8_t> opt_bytes) const;
// Sets the size of the receive buffer, in bytes, for the socket file
// descriptor. Returns true if successful.
// On failure, the errno is set. The caller should use PLOG() to print errno.
virtual bool SetReceiveBuffer(int size) const;
protected:
// Constructs the instance from an already created descriptor |fd| and its
// socket type |socket_type|. Protected: use the static factory functions.
Socket(base::ScopedFD fd, int socket_type);
// The socket file descriptor. It's always valid during the lifetime of the
// Socket instance.
base::ScopedFD fd_;
// Type of the socket. This is fetched once by the static factory function
// if an existing socket is passed in.
int socket_type_;
// The read watcher of the |fd_|. It should be destroyed before |fd_| is
// closed, so it's declared after |fd_| (members are destroyed in reverse
// order of declaration).
std::unique_ptr<base::FileDescriptorWatcher::Controller> watcher_;
private:
// Helper to perform RecvMessage on SOCK_STREAM-type sockets.
// NOTE(review): presumably follows the same |message| and return-value
// contract as RecvMessage() -- confirm against the implementation.
bool RecvStream(std::vector<uint8_t>* message) const;
};
// Factory that creates Socket instances. Its creation methods are virtual so
// that a MockSocketFactory can be injected in tests to hand out mock Socket
// instances.
//
// NOTE(review): the virtual methods below declare default arguments. Default
// arguments are resolved from the static type at the call site, so overrides
// should repeat exactly the same defaults.
class BRILLO_EXPORT SocketFactory {
public:
// Default value of |receive_buffer_size| for CreateNetlink() (512 * 1024).
// Keep this large enough to avoid overflows on IPv6 SNM routing update
// spikes.
static constexpr int kNetlinkReceiveBufferSize = 512 * 1024;
SocketFactory() = default;
virtual ~SocketFactory() = default;
// Creates the socket instance by the Socket::Create() method.
// On failure, returns nullptr and the errno is set. The caller should use
// PLOG() to print errno.
virtual std::unique_ptr<Socket> Create(int domain,
int type,
int protocol = 0);
// Creates the socket instance and binds to netlink. Sets the receive buffer
// size to |receive_buffer_size| if it is set; it defaults to
// kNetlinkReceiveBufferSize. The size is presumably in bytes, as for
// Socket::SetReceiveBuffer() -- confirm against the implementation.
// Returns nullptr on failure.
//
// Note: setting the receive buffer size above rmem_max requires
// CAP_NET_ADMIN.
virtual std::unique_ptr<Socket> CreateNetlink(
int netlink_family,
uint32_t netlink_groups_mask,
std::optional<int> receive_buffer_size = kNetlinkReceiveBufferSize);
};
// Stream insertion operator for Socket, e.g. for use in logging statements.
// Only the declaration of std::ostream is needed here (provided by <iosfwd>);
// the output format is defined by the implementation. By convention the
// operator is expected to return |stream| so that insertions can be chained.
BRILLO_EXPORT std::ostream& operator<<(std::ostream& stream,
                                       const Socket& socket);
} // namespace net_base
#endif // NET_BASE_SOCKET_H_