blob: a473797a611c16a1b9eaea112b76e967746f9fb5 [file] [log] [blame]
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "shill/net/netlink_sock_diag.h"
#include <linux/inet_diag.h>
#include <linux/netlink.h>
#include <linux/sock_diag.h>
#include <sys/socket.h>
#include <sys/utsname.h>
#include <string>
#include <utility>
#include <base/logging.h>
#include <base/memory/ptr_util.h>
#include <base/numerics/safe_conversions.h>
#include <base/strings/string_number_conversions.h>
#include <base/strings/string_split.h>
#include "shill/net/netlink_fd.h"
#include "shill/net/sockets.h"
namespace {
struct SockDiagRequest {
struct nlmsghdr header;
struct inet_diag_req_v2 req_opts;
};
SockDiagRequest CreateDumpRequest(uint8_t family,
uint8_t protocol,
int sequence_number) {
CHECK(family == AF_INET || family == AF_INET6)
<< "Unsupported SOCK_DIAG family " << family;
SockDiagRequest request;
request.header.nlmsg_len = sizeof(SockDiagRequest);
request.header.nlmsg_type = SOCK_DIAG_BY_FAMILY;
request.header.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP;
request.header.nlmsg_seq = sequence_number;
request.req_opts.sdiag_family = family;
request.req_opts.sdiag_protocol = protocol;
request.req_opts.idiag_states = -1; // all states
return request;
}
SockDiagRequest CreateDestroyRequest(uint8_t family, uint8_t protocol) {
SockDiagRequest request;
request.header.nlmsg_len = sizeof(SockDiagRequest);
request.header.nlmsg_type = SOCK_DESTROY;
request.header.nlmsg_flags = NLM_F_REQUEST;
request.req_opts.sdiag_family = family;
request.req_opts.sdiag_protocol = protocol;
request.req_opts.idiag_states = -1; // all states
return request;
}
} // namespace
namespace shill {
NetlinkSockDiag::NetlinkSockDiag(std::unique_ptr<Sockets> sockets,
int file_descriptor)
: sockets_(std::move(sockets)),
file_descriptor_(file_descriptor),
sequence_number_(0) {}
// static
std::unique_ptr<NetlinkSockDiag> NetlinkSockDiag::Create(
std::unique_ptr<Sockets> sockets) {
struct utsname un;
if (uname(&un) < 0) {
PLOG(ERROR) << "Could not check kernel version";
return nullptr;
}
unsigned major, minor;
std::vector<std::string> version = base::SplitString(
un.release, ".", base::TRIM_WHITESPACE, base::SPLIT_WANT_NONEMPTY);
if (version.size() < 2 || !base::StringToUint(version[0], &major) ||
!base::StringToUint(version[1], &minor)) {
LOG(ERROR) << "Error parsing kernel version";
return nullptr;
}
// SOCK_DESTROY has only been backported to 3.14 and above.
if (major < 3 || (major == 3 && minor < 14)) {
LOG(ERROR) << "Kernel version " << major << "." << minor
<< " does not support SOCK_DESTROY";
return nullptr;
}
int file_descriptor =
OpenNetlinkSocketFD(sockets.get(), NETLINK_SOCK_DIAG, 0);
if (file_descriptor == Sockets::kInvalidFileDescriptor)
return nullptr;
VLOG(2) << "Netlink sock_diag socket started";
return base::WrapUnique(
new NetlinkSockDiag(std::move(sockets), file_descriptor));
}
NetlinkSockDiag::~NetlinkSockDiag() {
sockets_->Close(file_descriptor_);
}
bool NetlinkSockDiag::DestroySockets(uint8_t protocol, const IPAddress& saddr) {
uint8_t family;
if (saddr.family() == IPAddress::kFamilyIPv4) {
family = AF_INET;
} else if (saddr.family() == IPAddress::kFamilyIPv6) {
family = AF_INET6;
} else {
LOG(ERROR) << "Tried to destroy sockets for unsupported family";
return false;
}
std::vector<struct inet_diag_sockid> socks;
if (!GetSockets(family, protocol, &socks))
return false;
SockDiagRequest request = CreateDestroyRequest(family, protocol);
for (const auto& sockid : socks) {
VLOG(1) << "Destroying socket (" << family << ", " << protocol << ")";
request.header.nlmsg_seq = ++sequence_number_;
request.req_opts.id = sockid;
if (!memcmp(sockid.idiag_src, saddr.GetConstData(), saddr.GetLength()) ||
sockets_->Send(file_descriptor_, static_cast<void*>(&request),
sizeof(request), 0) < 0) {
PLOG(ERROR) << "Failed to write request to netlink socket";
return false;
}
}
return true;
}
bool NetlinkSockDiag::GetSockets(
uint8_t family,
uint8_t protocol,
std::vector<struct inet_diag_sockid>* out_socks) {
CHECK(out_socks);
SockDiagRequest request =
CreateDumpRequest(family, protocol, ++sequence_number_);
if (sockets_->Send(file_descriptor_, static_cast<void*>(&request),
sizeof(request), 0) < 0) {
PLOG(ERROR) << "Failed to write sock_diag request to netlink socket "
<< "(family: " << family << ", protocol: " << protocol << ")";
return false;
}
return ReadDumpContents(out_socks);
}
bool NetlinkSockDiag::ReadDumpContents(
std::vector<struct inet_diag_sockid>* out_socks) {
char buf[8192];
out_socks->clear();
for (;;) {
ssize_t bytes_read = sockets_->RecvFrom(file_descriptor_, buf, sizeof(buf),
0, nullptr, nullptr);
if (bytes_read < 0) {
PLOG(ERROR) << "Failed to read from netlink socket";
return false;
}
size_t unsigned_bytes = base::checked_cast<size_t>(bytes_read);
for (nlmsghdr* nlh = reinterpret_cast<nlmsghdr*>(buf);
NLMSG_OK(nlh, unsigned_bytes); nlh = NLMSG_NEXT(nlh, unsigned_bytes)) {
switch (nlh->nlmsg_type) {
case NLMSG_DONE:
return true;
case NLMSG_ERROR: {
const nlmsgerr* err =
reinterpret_cast<const nlmsgerr*> NLMSG_DATA(nlh);
const char* err_msg = "Error parsing sock_diag netlink socket dump";
if (nlh->nlmsg_len < NLMSG_LENGTH(sizeof(*err))) {
LOG(ERROR) << err_msg;
} else {
errno = -err->error;
PLOG(ERROR) << err_msg;
}
return false;
}
case SOCK_DIAG_BY_FAMILY:
struct inet_diag_msg current_msg;
memcpy(&current_msg, NLMSG_DATA(nlh), sizeof(current_msg));
out_socks->push_back(current_msg.id);
break;
default:
LOG(WARNING) << "Ignoring unexpected netlink message type "
<< nlh->nlmsg_type;
break;
}
}
}
}
} // namespace shill