blob: 84c3290eb878543bd1dc30e15732d5d9080febce [file] [log] [blame] [edit]
// Copyright 2018 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "net-base/netlink_socket.h"
#include <linux/if_packet.h>
#include <linux/netlink.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <memory>
#include <utility>
#include <base/logging.h>
#include <base/posix/eintr_wrapper.h>
#include "net-base/netlink_message.h"
// SOL_NETLINK is defined in newer linux/socket.h headers than the one this
// code may be built against. Supply the kernel's value (270) only when the
// system headers did not already provide it, so the two can never conflict.
#ifndef SOL_NETLINK
#define SOL_NETLINK 270
#endif
namespace net_base {
// Convenience factory: builds a NetlinkSocket using a default-constructed
// SocketFactory. Returns nullptr if the underlying socket cannot be created.
std::unique_ptr<NetlinkSocket> NetlinkSocket::Create() {
  auto default_factory = std::make_unique<SocketFactory>();
  return CreateWithSocketFactory(std::move(default_factory));
}
std::unique_ptr<NetlinkSocket> NetlinkSocket::CreateWithSocketFactory(
std::unique_ptr<SocketFactory> socket_factory) {
std::unique_ptr<Socket> socket =
socket_factory->CreateNetlink(NETLINK_GENERIC, 0);
if (socket == nullptr) {
PLOG(ERROR) << "Failed to create AF_NETLINK socket";
return nullptr;
}
return std::unique_ptr<NetlinkSocket>(new NetlinkSocket(std::move(socket)));
}
NetlinkSocket::NetlinkSocket(std::unique_ptr<Socket> socket)
: socket_(std::move(socket)) {}
NetlinkSocket::~NetlinkSocket() = default;
// Reads from the underlying socket into |message| by delegating to
// Socket::RecvMessage(); that call's result is returned unchanged.
bool NetlinkSocket::RecvMessage(std::vector<uint8_t>* message) {
  const bool received = socket_->RecvMessage(message);
  return received;
}
// Sends |out_msg| with a single Send() call (flags = 0). Returns true only when
// every byte was sent. A send error (logged with errno) or a short write
// (logged with the byte counts) yields false.
bool NetlinkSocket::SendMessage(base::span<const uint8_t> out_msg) {
  const std::optional<size_t> sent = socket_->Send(out_msg, 0);
  if (!sent.has_value()) {
    PLOG(ERROR) << "Send failed.";
    return false;
  }

  const size_t sent_bytes = *sent;
  if (sent_bytes == out_msg.size()) {
    return true;
  }

  LOG(ERROR) << "Only sent " << sent_bytes << " bytes out of "
             << out_msg.size() << ".";
  return false;
}
// Joins netlink multicast group |group_id| on this socket via the
// NETLINK_ADD_MEMBERSHIP socket option. Returns false (and logs errno) if
// setsockopt() fails.
bool NetlinkSocket::SubscribeToEvents(uint32_t group_id) {
  if (setsockopt(socket_->Get(), SOL_NETLINK, NETLINK_ADD_MEMBERSHIP,
                 &group_id, sizeof(group_id)) >= 0) {
    return true;
  }
  PLOG(ERROR) << "setsockopt didn't work.";
  return false;
}
// Blocks until the netlink socket is readable or |timeout| elapses, using
// select(2) (retried on EINTR via HANDLE_EINTR).
// Returns select()'s result: a positive count when the socket is readable,
// 0 on timeout, or -1 on error. It also returns -1 without calling select()
// when the descriptor cannot be placed in an fd_set.
// |timeout| must be non-negative (DCHECKed).
int NetlinkSocket::WaitForRead(base::TimeDelta timeout) const {
  const int fd = socket_->Get();
  // FD_SET() is undefined for descriptors outside [0, FD_SETSIZE): a negative
  // or too-large fd would index outside |read_fds|. Reject both up front.
  if (fd < 0 || fd >= FD_SETSIZE) {
    LOG(ERROR) << "Invalid file_descriptor: " << fd;
    return -1;
  }

  fd_set read_fds;
  FD_ZERO(&read_fds);
  FD_SET(fd, &read_fds);

  DCHECK(!timeout.is_negative());
  // Split the timeout into whole seconds plus the leftover microseconds.
  const auto whole_seconds = timeout.InSeconds();
  struct timeval time = {
      .tv_sec = static_cast<time_t>(whole_seconds),
      .tv_usec = static_cast<suseconds_t>(
          (timeout - base::Seconds(whole_seconds)).InMicroseconds()),
  };
  return HANDLE_EINTR(select(fd + 1, &read_fds, nullptr, nullptr, &time));
}
// Advances the stored sequence number and returns it. The value
// NetlinkMessage::kBroadcastSequenceNumber is never returned: if the counter
// lands on it, the counter advances once more.
uint32_t NetlinkSocket::GetSequenceNumber() {
  do {
    ++sequence_number_;
  } while (sequence_number_ == NetlinkMessage::kBroadcastSequenceNumber);
  return sequence_number_;
}
} // namespace net_base.