| // Copyright 2018 The Chromium OS Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| #include "shill/net/netlink_sock_diag.h" |
| |
| #include <linux/inet_diag.h> |
| #include <linux/netlink.h> |
| #include <linux/sock_diag.h> |
| #include <sys/socket.h> |
| #include <sys/utsname.h> |
| |
| #include <string> |
| #include <utility> |
| |
| #include <base/logging.h> |
| #include <base/memory/ptr_util.h> |
| #include <base/numerics/safe_conversions.h> |
| #include <base/strings/string_number_conversions.h> |
| #include <base/strings/string_split.h> |
| |
| #include "shill/net/netlink_fd.h" |
| #include "shill/net/sockets.h" |
| |
| namespace { |
| |
| struct SockDiagRequest { |
| struct nlmsghdr header; |
| struct inet_diag_req_v2 req_opts; |
| }; |
| |
| SockDiagRequest CreateDumpRequest(uint8_t family, |
| uint8_t protocol, |
| int sequence_number) { |
| CHECK(family == AF_INET || family == AF_INET6) |
| << "Unsupported SOCK_DIAG family " << family; |
| |
| SockDiagRequest request; |
| request.header.nlmsg_len = sizeof(SockDiagRequest); |
| request.header.nlmsg_type = SOCK_DIAG_BY_FAMILY; |
| request.header.nlmsg_flags = NLM_F_REQUEST | NLM_F_DUMP; |
| request.header.nlmsg_seq = sequence_number; |
| request.req_opts.sdiag_family = family; |
| request.req_opts.sdiag_protocol = protocol; |
| request.req_opts.idiag_states = -1; // all states |
| return request; |
| } |
| |
| SockDiagRequest CreateDestroyRequest(uint8_t family, uint8_t protocol) { |
| SockDiagRequest request; |
| request.header.nlmsg_len = sizeof(SockDiagRequest); |
| request.header.nlmsg_type = SOCK_DESTROY; |
| request.header.nlmsg_flags = NLM_F_REQUEST; |
| request.req_opts.sdiag_family = family; |
| request.req_opts.sdiag_protocol = protocol; |
| request.req_opts.idiag_states = -1; // all states |
| return request; |
| } |
| |
| } // namespace |
| |
| namespace shill { |
| |
| NetlinkSockDiag::NetlinkSockDiag(std::unique_ptr<Sockets> sockets, |
| int file_descriptor) |
| : sockets_(std::move(sockets)), |
| file_descriptor_(file_descriptor), |
| sequence_number_(0) {} |
| |
| // static |
| std::unique_ptr<NetlinkSockDiag> NetlinkSockDiag::Create( |
| std::unique_ptr<Sockets> sockets) { |
| struct utsname un; |
| if (uname(&un) < 0) { |
| PLOG(ERROR) << "Could not check kernel version"; |
| return nullptr; |
| } |
| |
| unsigned major, minor; |
| std::vector<std::string> version = base::SplitString( |
| un.release, ".", base::TRIM_WHITESPACE, base::SPLIT_WANT_NONEMPTY); |
| if (version.size() < 2 || !base::StringToUint(version[0], &major) || |
| !base::StringToUint(version[1], &minor)) { |
| LOG(ERROR) << "Error parsing kernel version"; |
| return nullptr; |
| } |
| |
| // SOCK_DESTROY has only been backported to 3.14 and above. |
| if (major < 3 || (major == 3 && minor < 14)) { |
| LOG(ERROR) << "Kernel version " << major << "." << minor |
| << " does not support SOCK_DESTROY"; |
| return nullptr; |
| } |
| |
| int file_descriptor = |
| OpenNetlinkSocketFD(sockets.get(), NETLINK_SOCK_DIAG, 0); |
| if (file_descriptor == Sockets::kInvalidFileDescriptor) |
| return nullptr; |
| |
| VLOG(2) << "Netlink sock_diag socket started"; |
| return base::WrapUnique( |
| new NetlinkSockDiag(std::move(sockets), file_descriptor)); |
| } |
| |
| NetlinkSockDiag::~NetlinkSockDiag() { |
| sockets_->Close(file_descriptor_); |
| } |
| |
| bool NetlinkSockDiag::DestroySockets(uint8_t protocol, const IPAddress& saddr) { |
| uint8_t family; |
| if (saddr.family() == IPAddress::kFamilyIPv4) { |
| family = AF_INET; |
| } else if (saddr.family() == IPAddress::kFamilyIPv6) { |
| family = AF_INET6; |
| } else { |
| LOG(ERROR) << "Tried to destroy sockets for unsupported family"; |
| return false; |
| } |
| |
| std::vector<struct inet_diag_sockid> socks; |
| if (!GetSockets(family, protocol, &socks)) |
| return false; |
| |
| SockDiagRequest request = CreateDestroyRequest(family, protocol); |
| for (const auto& sockid : socks) { |
| VLOG(1) << "Destroying socket (" << family << ", " << protocol << ")"; |
| request.header.nlmsg_seq = ++sequence_number_; |
| request.req_opts.id = sockid; |
| if (!memcmp(sockid.idiag_src, saddr.GetConstData(), saddr.GetLength()) || |
| sockets_->Send(file_descriptor_, static_cast<void*>(&request), |
| sizeof(request), 0) < 0) { |
| PLOG(ERROR) << "Failed to write request to netlink socket"; |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| bool NetlinkSockDiag::GetSockets( |
| uint8_t family, |
| uint8_t protocol, |
| std::vector<struct inet_diag_sockid>* out_socks) { |
| CHECK(out_socks); |
| SockDiagRequest request = |
| CreateDumpRequest(family, protocol, ++sequence_number_); |
| if (sockets_->Send(file_descriptor_, static_cast<void*>(&request), |
| sizeof(request), 0) < 0) { |
| PLOG(ERROR) << "Failed to write sock_diag request to netlink socket " |
| << "(family: " << family << ", protocol: " << protocol << ")"; |
| return false; |
| } |
| |
| return ReadDumpContents(out_socks); |
| } |
| |
| bool NetlinkSockDiag::ReadDumpContents( |
| std::vector<struct inet_diag_sockid>* out_socks) { |
| char buf[8192]; |
| |
| out_socks->clear(); |
| |
| for (;;) { |
| ssize_t bytes_read = sockets_->RecvFrom(file_descriptor_, buf, sizeof(buf), |
| 0, nullptr, nullptr); |
| if (bytes_read < 0) { |
| PLOG(ERROR) << "Failed to read from netlink socket"; |
| return false; |
| } |
| |
| size_t unsigned_bytes = base::checked_cast<size_t>(bytes_read); |
| |
| for (nlmsghdr* nlh = reinterpret_cast<nlmsghdr*>(buf); |
| NLMSG_OK(nlh, unsigned_bytes); nlh = NLMSG_NEXT(nlh, unsigned_bytes)) { |
| switch (nlh->nlmsg_type) { |
| case NLMSG_DONE: |
| return true; |
| case NLMSG_ERROR: { |
| const nlmsgerr* err = |
| reinterpret_cast<const nlmsgerr*> NLMSG_DATA(nlh); |
| const char* err_msg = "Error parsing sock_diag netlink socket dump"; |
| if (nlh->nlmsg_len < NLMSG_LENGTH(sizeof(*err))) { |
| LOG(ERROR) << err_msg; |
| } else { |
| errno = -err->error; |
| PLOG(ERROR) << err_msg; |
| } |
| return false; |
| } |
| case SOCK_DIAG_BY_FAMILY: |
| struct inet_diag_msg current_msg; |
| memcpy(¤t_msg, NLMSG_DATA(nlh), sizeof(current_msg)); |
| out_socks->push_back(current_msg.id); |
| break; |
| default: |
| LOG(WARNING) << "Ignoring unexpected netlink message type " |
| << nlh->nlmsg_type; |
| break; |
| } |
| } |
| } |
| } |
| |
| } // namespace shill |