blob: 95c4a31284541bfa7abf472441e559a1f6fd0780 [file] [log] [blame] [edit]
// Copyright 2018 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "shill/network/icmp_session.h"
#include <cstdint>
#include <cstring>
#include <memory>
#include <optional>
#include <utility>
#include <vector>
#include <arpa/inet.h>
#include <net/if.h>
#include <netinet/icmp6.h>
#include <netinet/ip.h>
#include <base/check_op.h>
#include <base/containers/span.h>
#include <base/files/file_util.h>
#include <base/logging.h>
#include <base/time/time.h>
#include <net-base/byte_utils.h>
#include <net-base/socket.h>
#include "shill/event_dispatcher.h"
#include "shill/logging.h"
namespace {
// The IPv4 header's IHL field counts the header length in 4-byte (32-bit)
// words; multiply by this to get the header length in bytes.
constexpr int kIPHeaderLengthUnitBytes = 4;
}  // namespace
namespace shill {
// Log scope for this file's verbose SLOG() statements (presumably consulted by
// the SLOG macro — shill convention).
// NOTE(review): this is ScopeLogger::kWiFi although nothing in IcmpSession is
// WiFi-specific — confirm that is intended.
namespace Logging {
static auto kModuleLogScope = ScopeLogger::kWiFi;
} // namespace Logging
// Next echo ID to hand out. The constructor copies it into echo_id_ and then
// increments it, so successive sessions get distinct IDs (wrapping around
// after 65535 because the counter is a uint16_t).
uint16_t IcmpSession::kNextUniqueEchoId = 0;
// Builds a normal session, then replaces its automatically assigned echo ID
// with |echo_id| so tests can predict the ID carried by requests and replies.
std::unique_ptr<IcmpSession> IcmpSession::CreateForTesting(
    EventDispatcher* dispatcher,
    std::unique_ptr<net_base::SocketFactory> socket_factory,
    int echo_id) {
  auto session =
      std::make_unique<IcmpSession>(dispatcher, std::move(socket_factory));
  session->echo_id_ = echo_id;
  return session;
}
IcmpSession::IcmpSession(
    EventDispatcher* dispatcher,
    std::unique_ptr<net_base::SocketFactory> socket_factory)
    : dispatcher_(dispatcher),
      socket_factory_(std::move(socket_factory)),
      // Take the current counter value as this session's echo ID and advance
      // the shared counter, so every IcmpSession tags its requests (and
      // recognizes its replies) with a distinct ID.
      echo_id_(kNextUniqueEchoId++) {}
// Tear down any running session: Stop() cancels the pending timeout and
// releases the socket.
IcmpSession::~IcmpSession() {
  this->Stop();
}
// Starts an ICMP echo session towards |destination| over |interface_name|
// (with |interface_index| used as the IPv6 scope ID). Echo requests are sent
// asynchronously on |dispatcher_|; |result_callback| is run once, either when
// every request has been answered or when kTimeout expires.
// Returns false (and leaves the session stopped) if there is no dispatcher,
// the session is already running, the interface name is too long, or the raw
// socket cannot be created, made non-blocking, or bound to the interface.
bool IcmpSession::Start(const net_base::IPAddress& destination,
                        int interface_index,
                        std::string_view interface_name,
                        IcmpSessionResultCallback result_callback) {
  if (!dispatcher_) {
    LOG(ERROR) << "Invalid dispatcher";
    return false;
  }
  if (IsStarted()) {
    LOG(WARNING) << "ICMP session already started";
    return false;
  }
  // Validate the interface name before creating the raw socket: the name plus
  // its NUL terminator must fit in ifreq::ifr_name (IFNAMSIZ bytes).
  if (interface_name.size() >= IFNAMSIZ) {
    LOG(ERROR) << "The interface name '" << interface_name << "' is too long";
    return false;
  }

  std::unique_ptr<net_base::Socket> socket;
  switch (destination.GetFamily()) {
    case net_base::IPFamily::kIPv4:
      socket = socket_factory_->Create(AF_INET, SOCK_RAW | SOCK_CLOEXEC,
                                       IPPROTO_ICMP);
      break;
    case net_base::IPFamily::kIPv6:
      socket = socket_factory_->Create(AF_INET6, SOCK_RAW | SOCK_CLOEXEC,
                                       IPPROTO_ICMPV6);
      break;
  }
  if (socket == nullptr) {
    PLOG(ERROR) << "Could not create ICMP socket";
    return false;
  }
  if (!base::SetNonBlocking(socket->Get())) {
    PLOG(ERROR) << "Could not set socket to be non-blocking";
    return false;
  }

  // Bind the socket to |interface_name| so echo traffic is restricted to that
  // interface. |ifr| is zeroed first, so the copied name is NUL-terminated.
  struct ifreq ifr;
  memset(&ifr, 0, sizeof(ifr));
  memcpy(ifr.ifr_name, interface_name.data(), interface_name.size());
  if (!socket->SetSockOpt(SOL_SOCKET, SO_BINDTODEVICE,
                          net_base::byte_utils::AsBytes(ifr))) {
    PLOG(ERROR) << "Failed to bind socket on " << interface_name;
    return false;
  }

  socket_ = std::move(socket);
  destination_ = destination;
  interface_index_ = interface_index;
  result_callback_ = std::move(result_callback);
  // base::Unretained: |socket_| is owned by this object and is released in
  // Stop()/the destructor (assumes net_base::Socket stops invoking the
  // callback once destroyed).
  socket_->SetReadableCallback(base::BindRepeating(
      &IcmpSession::OnIcmpReadable, base::Unretained(this)));
  // Report whatever has been collected if not every reply arrives in time.
  timeout_callback_.Reset(
      base::BindOnce(&IcmpSession::ReportResultAndStopSession,
                     weak_ptr_factory_.GetWeakPtr()));
  dispatcher_->PostDelayedTask(FROM_HERE, timeout_callback_.callback(),
                               kTimeout);
  seq_num_to_sent_recv_time_.clear();
  received_echo_reply_seq_numbers_.clear();
  dispatcher_->PostTask(FROM_HERE,
                        base::BindOnce(&IcmpSession::TransmitEchoRequestTask,
                                       weak_ptr_factory_.GetWeakPtr()));
  return true;
}
// Ends a running session: cancels the pending timeout and releases the socket
// (which also marks the session as not started). A no-op when not started.
// Tasks already posted through |weak_ptr_factory_| are not invalidated here;
// TransmitEchoRequestTask() re-checks IsStarted() instead.
void IcmpSession::Stop() {
  if (IsStarted()) {
    timeout_callback_.Cancel();
    socket_.reset();
  }
}
// A session is running exactly while it owns a socket: Start() assigns it
// only after every setup step succeeded, and Stop() releases it.
bool IcmpSession::IsStarted() const {
  return static_cast<bool>(socket_);
}
// static
bool IcmpSession::AnyRepliesReceived(const IcmpSessionResult& result) {
for (const base::TimeDelta& latency : result) {
if (!latency.is_zero()) {
return true;
}
}
return false;
}
// static
bool IcmpSession::IsPacketLossPercentageGreaterThan(
const IcmpSessionResult& result, int percentage_threshold) {
if (percentage_threshold < 0) {
LOG(ERROR) << __func__ << ": negative percentage threshold ("
<< percentage_threshold << ")";
return false;
}
if (result.empty()) {
return false;
}
int lost_packet_count = 0;
for (const base::TimeDelta& latency : result) {
if (latency.is_zero()) {
++lost_packet_count;
}
}
int packet_loss_percentage = (lost_packet_count * 100) / result.size();
return packet_loss_percentage > percentage_threshold;
}
// static
// Computes the Internet Checksum (RFC 1071) over the first |len| bytes
// starting at |hdr|. |len| must not exceed the size of the buffer that |hdr|
// starts (callers pass sizeof(hdr)). Returns the one's-complement checksum,
// ready to be stored in icmphdr::checksum.
uint16_t IcmpSession::ComputeIcmpChecksum(const struct icmphdr& hdr,
                                          size_t len) {
  // Walk the bytes through a uint8_t pointer (which may alias any object) and
  // copy each 16-bit word out with memcpy. Dereferencing a
  // reinterpret_cast<const uint16_t*> of an icmphdr would violate strict
  // aliasing.
  const uint8_t* bytes = reinterpret_cast<const uint8_t*>(&hdr);
  uint32_t sum = 0;
  while (len > 1) {
    uint16_t word;
    memcpy(&word, bytes, sizeof(word));
    sum += word;
    bytes += sizeof(word);
    len -= sizeof(word);
  }
  // Add the left-over byte, if any, as a word padded with a zero byte in
  // memory order. This is correct regardless of host byte order (RFC 1071's
  // sample "*(unsigned char *)addr" is only right on little-endian hosts) and
  // yields the same value as that sample on little-endian hosts.
  if (len > 0) {
    const uint8_t padded[2] = {*bytes, 0};
    uint16_t word;
    memcpy(&word, padded, sizeof(word));
    sum += word;
  }
  // Fold the 32-bit sum into 16 bits by adding the carries back in.
  while (sum >> 16) {
    sum = (sum & 0xffff) + (sum >> 16);
  }
  return static_cast<uint16_t>(~sum);
}
// Sends one echo request to |destination_| and, until kTotalNumEchoRequests
// requests have been sent successfully, reschedules itself after
// kEchoRequestInterval.
void IcmpSession::TransmitEchoRequestTask() {
  if (!IsStarted()) {
    // This might happen when ping times out or is stopped between two calls
    // to IcmpSession::TransmitEchoRequestTask.
    return;
  }
  DCHECK(destination_);
  const bool success =
      (destination_->GetFamily() == net_base::IPFamily::kIPv4)
          ? TransmitV4EchoRequest(*destination_->ToIPv4Address())
          : TransmitV6EchoRequest(*destination_->ToIPv6Address());
  if (success) {
    // Record the send time. The receive time stays null until a reply
    // arrives.
    seq_num_to_sent_recv_time_[current_sequence_number_] =
        std::make_pair(base::TimeTicks::Now(), base::TimeTicks());
  }
  // Advance the sequence number even on failure, so a retry uses a new one.
  ++current_sequence_number_;
  // If we fail to transmit the echo request, fall through instead of
  // returning, so we continue sending echo requests until
  // |kTotalNumEchoRequests| echo requests are sent.
  //
  // Compare with "<" rather than "!=". Stop() does not invalidate weak
  // pointers, so a task posted before Stop() can still run after a later
  // Start() and form a second transmit chain. Two chains can push the map
  // past |kTotalNumEchoRequests|, where "!=" would keep rescheduling until
  // the session stops.
  if (seq_num_to_sent_recv_time_.size() < kTotalNumEchoRequests) {
    dispatcher_->PostDelayedTask(
        FROM_HERE,
        base::BindOnce(&IcmpSession::TransmitEchoRequestTask,
                       weak_ptr_factory_.GetWeakPtr()),
        kEchoRequestInterval);
  }
}
// Sends one ICMPv4 echo request, carrying |echo_id_| and
// |current_sequence_number_|, to |address|. Returns true only if sendto()
// accepted the whole request.
bool IcmpSession::TransmitV4EchoRequest(const net_base::IPv4Address& address) {
  struct icmphdr icmp_header;
  memset(&icmp_header, 0, sizeof(icmp_header));
  icmp_header.type = ICMP_ECHO;
  icmp_header.code = kIcmpEchoCode;
  icmp_header.un.echo.id = echo_id_;
  icmp_header.un.echo.sequence = current_sequence_number_;
  // The ICMPv4 checksum is computed here over the zero-checksum header.
  icmp_header.checksum = ComputeIcmpChecksum(icmp_header, sizeof(icmp_header));
  const base::span<const uint8_t> payload = {
      reinterpret_cast<const uint8_t*>(&icmp_header), sizeof(icmp_header)};
  // Zero the whole address so that no uninitialized bytes (sin_port,
  // sin_zero) are handed to sendto().
  struct sockaddr_in destination_address;
  memset(&destination_address, 0, sizeof(destination_address));
  destination_address.sin_family = AF_INET;
  destination_address.sin_addr = address.ToInAddr();
  const std::optional<size_t> result = socket_->SendTo(
      payload, 0, reinterpret_cast<struct sockaddr*>(&destination_address),
      sizeof(destination_address));
  if (!result) {
    PLOG(ERROR) << "Socket sendto failed";
  } else if (*result < payload.size()) {
    LOG(ERROR) << "Socket sendto returned " << *result
               << " which is less than the expected result " << payload.size();
  }
  return result == payload.size();
}
// Sends one ICMPv6 echo request, carrying |echo_id_| and
// |current_sequence_number_|, to |address| (scoped to |interface_index_|).
// Returns true only if sendto() accepted the whole request.
bool IcmpSession::TransmitV6EchoRequest(const net_base::IPv6Address& address) {
  struct icmp6_hdr request;
  memset(&request, 0, sizeof(request));
  request.icmp6_type = ICMP6_ECHO_REQUEST;
  request.icmp6_code = kIcmpEchoCode;
  request.icmp6_id = echo_id_;
  request.icmp6_seq = current_sequence_number_;
  // icmp6_cksum is left zero: for IPPROTO_ICMPV6 sockets the kernel fills it
  // in (RFC3542 section 3.1).

  struct sockaddr_in6 dest;
  memset(&dest, 0, sizeof(dest));
  dest.sin6_family = AF_INET6;
  dest.sin6_addr = address.ToIn6Addr();
  dest.sin6_scope_id = interface_index_;

  const base::span<const uint8_t> payload = {
      reinterpret_cast<const uint8_t*>(&request), sizeof(request)};
  const std::optional<size_t> sent =
      socket_->SendTo(payload, 0, reinterpret_cast<struct sockaddr*>(&dest),
                      sizeof(dest));
  if (!sent.has_value()) {
    PLOG(ERROR) << "Socket sendto failed";
    return false;
  }
  if (*sent < payload.size()) {
    LOG(ERROR) << "Socket sendto returned " << *sent
               << " which is less than the expected result " << payload.size();
  }
  return *sent == payload.size();
}
// Parses an ICMPv4 packet (IPv4 header included) read from the raw socket.
// Returns the echo sequence number if it is an echo reply for this session,
// or -1 otherwise.
int IcmpSession::OnV4EchoReplyReceived(base::span<const uint8_t> message) {
  // ICMPv4 raw sockets deliver the IPv4 header in front of the ICMP message,
  // so locate the ICMP header by skipping IHL * 4 bytes.
  if (message.size() < sizeof(struct iphdr)) {
    LOG(WARNING) << "Received ICMP packet is too short to contain IP header";
    return -1;
  }
  const auto* ip_header = reinterpret_cast<const struct iphdr*>(message.data());
  const size_t icmp_offset = ip_header->ihl * kIPHeaderLengthUnitBytes;
  if (message.size() < icmp_offset + sizeof(struct icmphdr)) {
    LOG(WARNING) << "Received ICMP packet is too short to contain ICMP header";
    return -1;
  }
  const auto* icmp_header =
      reinterpret_cast<const struct icmphdr*>(message.data() + icmp_offset);

  // Other ICMP traffic can show up on this socket; silently skip anything
  // that is not an echo reply.
  if (icmp_header->type != ICMP_ECHOREPLY) {
    return -1;
  }
  // The reply must carry the expected code and this session's echo ID.
  if (icmp_header->code != kIcmpEchoCode) {
    LOG(WARNING) << "ICMP header code is invalid";
    return -1;
  }
  const auto reply_id = icmp_header->un.echo.id;
  if (reply_id != echo_id_) {
    SLOG(2) << "received message echo id (" << reply_id
            << ") does not match this ICMP session's echo id (" << echo_id_
            << ")";
    return -1;
  }
  return icmp_header->un.echo.sequence;
}
// Parses an ICMPv6 message read from the raw socket. Returns the echo
// sequence number if it is an echo reply for this session, or -1 otherwise.
int IcmpSession::OnV6EchoReplyReceived(base::span<const uint8_t> message) {
  // Per RFC3542 section 3, ICMPv6 raw sockets do not deliver the IP header
  // (unlike ICMPv4 raw sockets), so the data starts with the ICMPv6 header.
  if (message.size() < sizeof(struct icmp6_hdr)) {
    LOG(WARNING)
        << "Received ICMP packet is too short to contain ICMPv6 header";
    return -1;
  }
  const auto* reply =
      reinterpret_cast<const struct icmp6_hdr*>(message.data());

  // Silently skip ICMPv6 messages other than echo replies.
  if (reply->icmp6_type != ICMP6_ECHO_REPLY) {
    return -1;
  }
  // The reply must carry the expected code and this session's echo ID.
  if (reply->icmp6_code != kIcmpEchoCode) {
    LOG(WARNING) << "ICMPv6 header code is invalid";
    return -1;
  }
  const auto reply_id = reply->icmp6_id;
  if (reply_id != echo_id_) {
    SLOG(2) << "received message echo id (" << reply_id
            << ") does not match this ICMPv6 session's echo id (" << echo_id_
            << ")";
    return -1;
  }
  return reply->icmp6_seq;
}
// Readable callback for |socket_|: reads one packet and hands it to
// OnEchoReplyReceived().
void IcmpSession::OnIcmpReadable() {
  std::vector<uint8_t> packet;
  if (!socket_->RecvMessage(&packet)) {
    PLOG(ERROR) << __func__ << ": failed to receive message from socket";
    // An IO error is not fatal: keep the session running so other pending
    // echo replies can still be received.
    return;
  }
  // Nothing may follow this call: it can end the session and run the result
  // callback, which may destroy this object.
  OnEchoReplyReceived(packet);
}
// Handles one received packet: records the receive time of the matching echo
// request, and ends the session once every request has been answered.
void IcmpSession::OnEchoReplyReceived(base::span<const uint8_t> message) {
  if (!destination_) {
    LOG(WARNING) << "Failed to get ICMP destination";
    return;
  }
  int seq_num = -1;
  switch (destination_->GetFamily()) {
    case net_base::IPFamily::kIPv4:
      seq_num = OnV4EchoReplyReceived(message);
      break;
    case net_base::IPFamily::kIPv6:
      seq_num = OnV6EchoReplyReceived(message);
      break;
  }
  // Negative means the packet is not an echo reply for this session.
  if (seq_num < 0) {
    return;
  }
  // Ignore duplicate replies for a sequence number already handled.
  if (received_echo_reply_seq_numbers_.count(seq_num) != 0) {
    return;
  }
  // Ignore replies that do not match any request that was actually sent.
  auto sent_entry = seq_num_to_sent_recv_time_.find(seq_num);
  if (sent_entry == seq_num_to_sent_recv_time_.end()) {
    return;
  }
  sent_entry->second.second = base::TimeTicks::Now();
  received_echo_reply_seq_numbers_.insert(seq_num);
  // Every request sent and answered: report and end the session. This may
  // destroy this object, so nothing may follow it.
  if (received_echo_reply_seq_numbers_.size() == kTotalNumEchoRequests) {
    ReportResultAndStopSession();
  }
}
std::vector<base::TimeDelta> IcmpSession::GenerateIcmpResult() {
std::vector<base::TimeDelta> latencies;
for (const auto& seq_num_to_sent_recv_time_pair :
seq_num_to_sent_recv_time_) {
const SentRecvTimePair& sent_recv_timestamp_pair =
seq_num_to_sent_recv_time_pair.second;
if (sent_recv_timestamp_pair.second.is_null()) {
// Invalid latency if an echo response has not been received.
latencies.push_back(base::TimeDelta());
} else {
latencies.push_back(sent_recv_timestamp_pair.second -
sent_recv_timestamp_pair.first);
}
}
return latencies;
}
// Stops the running session and delivers the collected latencies to
// |result_callback_| (if set). Called when all replies arrived or on timeout.
void IcmpSession::ReportResultAndStopSession() {
  if (!IsStarted()) {
    LOG(WARNING) << "ICMP session not started";
    return;
  }
  // Stop first: the result callback may delete this object, after which no
  // member (including Stop()) may be touched.
  Stop();
  if (result_callback_.is_null()) {
    return;
  }
  std::move(result_callback_).Run(GenerateIcmpResult());
}
} // namespace shill