blob: dfad84b3001456260858e8557cba2721cf9f6fa3 [file] [log] [blame]
// Copyright 2020 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "patchpanel/broadcast_forwarder.h"
#include <arpa/inet.h>
#include <errno.h>
#include <linux/filter.h>
#include <linux/if_ether.h>
#include <linux/if_packet.h>
#include <linux/rtnetlink.h>
#include <net/if.h>
#include <netinet/ip.h>
#include <netinet/udp.h>
#include <shill/net/rtnl_handler.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <utility>
#include <base/bind.h>
#include <base/logging.h>
#include "patchpanel/socket.h"
namespace {
// Size in bytes of the stack buffers used to receive a packet and to rebuild
// the packet that is forwarded to guests.
constexpr int kBufSize = 4096;
// Mask applied to the host-order |frag_off| field of the IP header, together
// with IP_MF, to detect and drop fragmented packets.
constexpr uint16_t kIpFragOffsetMask = 0x1FFF;
// Broadcast forwarder will not forward system ports (0 - 1023).
constexpr uint16_t kMinValidPort = 1024;
// SetBcastSockFilter filters out packets by only accepting (all conditions
// must be fulfilled):
// - UDP protocol,
// - Destination address equals to 255.255.255.255 or |bcast_addr|,
// - Source and destination port is not a system port (>= 1024).
bool SetBcastSockFilter(int fd, uint32_t bcast_addr) {
sock_filter kBcastFwdBpfInstructions[] = {
// Load IP protocol value.
BPF_STMT(BPF_LD | BPF_B | BPF_ABS, offsetof(iphdr, protocol)),
// Check if equals UDP, if not, then goto return 0.
BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, IPPROTO_UDP, 0, 8),
// Load IP destination address.
BPF_STMT(BPF_LD | BPF_W | BPF_IND, offsetof(iphdr, daddr)),
// Check if it is a broadcast address.
// All 1s.
BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, patchpanel::kBcastAddr, 1, 0),
// Current interface broadcast address.
BPF_JUMP(BPF_JMP | BPF_JEQ | BPF_K, htonl(bcast_addr), 0, 5),
// Move index to start of UDP header.
BPF_STMT(BPF_LDX | BPF_IMM, sizeof(iphdr)),
// Load UDP source port.
BPF_STMT(BPF_LD | BPF_H | BPF_IND, offsetof(udphdr, uh_sport)),
// Check if it is a valid source port (>= 1024).
BPF_JUMP(BPF_JMP | BPF_JGE | BPF_K, kMinValidPort, 0, 2),
// Load UDP destination port.
BPF_STMT(BPF_LD | BPF_H | BPF_IND, offsetof(udphdr, uh_dport)),
// Check if it is a valid destination port (>= 1024).
BPF_JUMP(BPF_JMP | BPF_JGE | BPF_K, kMinValidPort, 1, 0),
// Return 0.
BPF_STMT(BPF_RET | BPF_K, 0),
// Return MAX.
BPF_STMT(BPF_RET | BPF_K, IP_MAXPACKET),
};
sock_fprog kBcastFwdBpfProgram = {
.len = sizeof(kBcastFwdBpfInstructions) / sizeof(sock_filter),
.filter = kBcastFwdBpfInstructions};
if (setsockopt(fd, SOL_SOCKET, SO_ATTACH_FILTER, &kBcastFwdBpfProgram,
sizeof(kBcastFwdBpfProgram)) != 0) {
PLOG(ERROR)
<< "setsockopt(SO_ATTACH_FILTER) failed for broadcast forwarder";
return false;
}
return true;
}
// Issues the interface ioctl |cmd| on |fd| for the interface named |ifname|
// and stores the kernel's answer in |ifr|.
//
// |ifr| is zeroed unconditionally before anything else happens, so callers
// never observe stale or uninitialized bytes, including on the early return
// taken for an empty |ifname|. Failures are only logged (EADDRNOTAVAIL is
// silently ignored); nothing is reported back to the caller.
void Ioctl(int fd,
           const std::string& ifname,
           unsigned int cmd,
           struct ifreq* ifr) {
  memset(ifr, 0, sizeof(*ifr));
  if (ifname.empty()) {
    LOG(WARNING) << "Empty interface name";
    return;
  }
  // Copy at most IFNAMSIZ - 1 bytes: the buffer was zeroed above, so the name
  // is guaranteed to stay NUL-terminated even if |ifname| is too long.
  strncpy(ifr->ifr_name, ifname.c_str(), IFNAMSIZ - 1);
  // Ignore EADDRNOTAVAIL: IPv4 was not provisioned.
  if (ioctl(fd, cmd, ifr) < 0 && errno != EADDRNOTAVAIL) {
    PLOG(ERROR) << "ioctl call failed for " << ifname;
  }
}
// Returns the raw |s_addr| value held in the |ifr_addr| field of |ifr|,
// interpreting that field as a sockaddr_in.
uint32_t GetIfreqAddr(const struct ifreq& ifr) {
  const auto* sin = reinterpret_cast<const struct sockaddr_in*>(&ifr.ifr_addr);
  return sin->sin_addr.s_addr;
}
// Returns the raw |s_addr| value held in the |ifr_broadaddr| field of |ifr|,
// interpreting that field as a sockaddr_in.
uint32_t GetIfreqBroadaddr(const struct ifreq& ifr) {
  const auto* sin =
      reinterpret_cast<const struct sockaddr_in*>(&ifr.ifr_broadaddr);
  return sin->sin_addr.s_addr;
}
// Returns the raw |s_addr| value held in the |ifr_netmask| field of |ifr|,
// interpreting that field as a sockaddr_in.
uint32_t GetIfreqNetmask(const struct ifreq& ifr) {
  const auto* sin =
      reinterpret_cast<const struct sockaddr_in*>(&ifr.ifr_netmask);
  return sin->sin_addr.s_addr;
}
} // namespace
namespace patchpanel {
// Takes ownership of |fd|, records the given addresses and registers
// |callback|, bound to the descriptor number, with
// base::FileDescriptorWatcher::WatchReadable() for that descriptor.
BroadcastForwarder::Socket::Socket(base::ScopedFD fd,
                                   const base::Callback<void(int)>& callback,
                                   uint32_t addr,
                                   uint32_t broadaddr,
                                   uint32_t netmask)
    : fd(std::move(fd)), addr(addr), broadaddr(broadaddr), netmask(netmask) {
  // The |fd| parameter has been moved from; only the member holds the
  // descriptor from here on.
  const int raw_fd = this->fd.get();
  watcher = base::FileDescriptorWatcher::WatchReadable(
      raw_fd, base::BindRepeating(callback, raw_fd));
}
// Creates a forwarder for the interface |dev_ifname|. Registers
// AddrMsgHandler() (bound through a weak pointer) as an RTNL address listener
// and starts the RTNL handler for the RTMGRP_IPV4_IFADDR group. No socket is
// opened here.
BroadcastForwarder::BroadcastForwarder(const std::string& dev_ifname)
    : dev_ifname_(dev_ifname) {
  auto on_addr_msg = base::Bind(&BroadcastForwarder::AddrMsgHandler,
                                weak_factory_.GetWeakPtr());
  addr_listener_ = std::make_unique<shill::RTNLListener>(
      shill::RTNLHandler::kRequestAddr, std::move(on_addr_msg));
  shill::RTNLHandler::GetInstance()->Start(RTMGRP_IPV4_IFADDR);
}
// Handles an RTNL address message.
//
// Only "add" messages whose IFA_LABEL equals |dev_ifname_| are processed:
// - IFA_ADDRESS updates the address stored in |dev_socket_|,
// - IFA_BROADCAST updates the stored broadcast address and replaces
//   |dev_socket_| with a freshly bound raw socket (BindRaw() installs a filter
//   derived from the interface broadcast address).
//
// Messages are ignored while |dev_socket_| does not exist yet: AddGuest()
// queries the interface addresses itself when it creates that socket.
void BroadcastForwarder::AddrMsgHandler(const shill::RTNLMessage& msg) {
  if (!msg.HasAttribute(IFA_LABEL)) {
    LOG(ERROR) << "Address event message does not have IFA_LABEL";
    return;
  }
  if (msg.mode() != shill::RTNLMessage::kModeAdd)
    return;

  shill::ByteString label(msg.GetAttribute(IFA_LABEL));
  std::string ifname(reinterpret_cast<const char*>(
      label.GetSubstring(0, IFNAMSIZ).GetConstData()));
  if (ifname != dev_ifname_)
    return;

  // The listener is registered in the constructor, but |dev_socket_| is only
  // created by AddGuest(); there is nothing to update before that.
  if (!dev_socket_)
    return;

  // Interface address is added.
  if (msg.HasAttribute(IFA_ADDRESS)) {
    shill::ByteString addr(msg.GetAttribute(IFA_ADDRESS));
    // Only copy payloads that exactly fill the destination field; copying the
    // attribute's own length could write past it.
    if (addr.GetLength() == sizeof(dev_socket_->addr)) {
      memcpy(&dev_socket_->addr, addr.GetConstData(),
             sizeof(dev_socket_->addr));
    } else {
      LOG(WARNING) << "Ignoring IFA_ADDRESS with unexpected length "
                   << addr.GetLength();
    }
  }

  // Broadcast address is added.
  if (msg.HasAttribute(IFA_BROADCAST)) {
    shill::ByteString broadaddr(msg.GetAttribute(IFA_BROADCAST));
    if (broadaddr.GetLength() != sizeof(dev_socket_->broadaddr)) {
      LOG(WARNING) << "Ignoring IFA_BROADCAST with unexpected length "
                   << broadaddr.GetLength();
      return;
    }
    memcpy(&dev_socket_->broadaddr, broadaddr.GetConstData(),
           sizeof(dev_socket_->broadaddr));

    base::ScopedFD dev_fd(BindRaw(dev_ifname_));
    if (!dev_fd.is_valid()) {
      LOG(WARNING) << "Could not bind socket on " << dev_ifname_;
      return;
    }
    // Read the values out of the old socket before it is replaced.
    const uint32_t dev_addr = dev_socket_->addr;
    const uint32_t dev_broadaddr = dev_socket_->broadaddr;
    dev_socket_ = std::make_unique<Socket>(
        std::move(dev_fd),
        base::BindRepeating(&BroadcastForwarder::OnFileCanReadWithoutBlocking,
                            base::Unretained(this)),
        dev_addr, dev_broadaddr);
  }
}
// Creates a UDP socket bound to the device |ifname| (SO_BINDTODEVICE) and to
// INADDR_ANY:|port|, with SO_BROADCAST and SO_REUSEADDR enabled. |port| is in
// host byte order. Returns an invalid ScopedFD (after logging) if any step
// fails.
base::ScopedFD BroadcastForwarder::Bind(const std::string& ifname,
                                        uint16_t port) {
  base::ScopedFD fd(socket(AF_INET, SOCK_DGRAM | SOCK_CLOEXEC, 0));
  if (!fd.is_valid()) {
    PLOG(ERROR) << "socket() failed for broadcast forwarder on " << ifname
                << " for port: " << port;
    return base::ScopedFD();
  }

  struct ifreq ifr;
  memset(&ifr, 0, sizeof(ifr));
  // Copy at most IFNAMSIZ - 1 bytes so the zeroed buffer keeps its trailing
  // NUL even when |ifname| is too long.
  strncpy(ifr.ifr_name, ifname.c_str(), IFNAMSIZ - 1);
  if (setsockopt(fd.get(), SOL_SOCKET, SO_BINDTODEVICE, &ifr, sizeof(ifr)) <
      0) {
    PLOG(ERROR)
        << "setsockopt(SO_BINDTODEVICE) failed for broadcast forwarder on "
        << ifname << " for port: " << port;
    return base::ScopedFD();
  }

  int on = 1;
  if (setsockopt(fd.get(), SOL_SOCKET, SO_BROADCAST, &on, sizeof(on)) < 0) {
    PLOG(ERROR) << "setsockopt(SO_BROADCAST) failed for broadcast forwarder on "
                << ifname << " for: " << port;
    return base::ScopedFD();
  }
  if (setsockopt(fd.get(), SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) < 0) {
    PLOG(ERROR) << "setsockopt(SO_REUSEADDR) failed for broadcast forwarder on "
                << ifname << " for: " << port;
    return base::ScopedFD();
  }

  struct sockaddr_in bind_addr;
  memset(&bind_addr, 0, sizeof(bind_addr));
  bind_addr.sin_addr.s_addr = htonl(INADDR_ANY);
  bind_addr.sin_family = AF_INET;
  bind_addr.sin_port = htons(port);
  if (bind(fd.get(), reinterpret_cast<const struct sockaddr*>(&bind_addr),
           sizeof(bind_addr)) < 0) {
    PLOG(ERROR) << "bind(" << port << ") failed for broadcast forwarder on "
                << ifname << " for: " << port;
    return base::ScopedFD();
  }
  return fd;
}
// Creates an AF_PACKET/SOCK_DGRAM socket for ETH_P_IP, binds it to the
// interface |ifname| and attaches the broadcast filter built from that
// interface's broadcast address (see SetBcastSockFilter()). Returns an invalid
// ScopedFD (after logging) if any step fails.
base::ScopedFD BroadcastForwarder::BindRaw(const std::string& ifname) {
  base::ScopedFD fd(
      socket(AF_PACKET, SOCK_DGRAM | SOCK_CLOEXEC, htons(ETH_P_IP)));
  if (!fd.is_valid()) {
    PLOG(ERROR) << "socket() failed for raw socket";
    return base::ScopedFD();
  }

  struct ifreq ifr;
  memset(&ifr, 0, sizeof(ifr));
  // Copy at most IFNAMSIZ - 1 bytes so the zeroed buffer keeps its trailing
  // NUL even when |ifname| is too long.
  strncpy(ifr.ifr_name, ifname.c_str(), IFNAMSIZ - 1);
  if (ioctl(fd.get(), SIOCGIFINDEX, &ifr) < 0) {
    PLOG(ERROR) << "SIOCGIFINDEX failed for " << ifname;
    return base::ScopedFD();
  }

  struct sockaddr_ll bindaddr;
  memset(&bindaddr, 0, sizeof(bindaddr));
  bindaddr.sll_family = AF_PACKET;
  bindaddr.sll_protocol = htons(ETH_P_IP);
  bindaddr.sll_ifindex = ifr.ifr_ifindex;
  if (bind(fd.get(), reinterpret_cast<const struct sockaddr*>(&bindaddr),
           sizeof(bindaddr)) < 0) {
    PLOG(ERROR) << "bind() failed for broadcast forwarder on " << ifname;
    return base::ScopedFD();
  }

  // The filter needs the interface broadcast address; Ioctl() only logs on
  // failure, in which case whatever it left in |ifr| is used as is.
  Ioctl(fd.get(), ifname, SIOCGIFBRDADDR, &ifr);
  const uint32_t bcast_addr = GetIfreqBroadaddr(ifr);
  if (!SetBcastSockFilter(fd.get(), bcast_addr)) {
    return base::ScopedFD();
  }
  return fd;
}
bool BroadcastForwarder::AddGuest(const std::string& br_ifname) {
if (br_sockets_.find(br_ifname) != br_sockets_.end()) {
LOG(WARNING) << "Forwarding is already started between " << dev_ifname_
<< " and " << br_ifname;
return false;
}
base::ScopedFD br_fd(BindRaw(br_ifname));
if (!br_fd.is_valid()) {
LOG(WARNING) << "Could not bind socket on " << br_ifname;
return false;
}
struct ifreq ifr;
Ioctl(br_fd.get(), br_ifname, SIOCGIFADDR, &ifr);
uint32_t br_addr = GetIfreqAddr(ifr);
Ioctl(br_fd.get(), br_ifname, SIOCGIFBRDADDR, &ifr);
uint32_t br_broadaddr = GetIfreqBroadaddr(ifr);
Ioctl(br_fd.get(), br_ifname, SIOCGIFNETMASK, &ifr);
uint32_t br_netmask = GetIfreqNetmask(ifr);
std::unique_ptr<Socket> br_socket = std::make_unique<Socket>(
std::move(br_fd),
base::BindRepeating(&BroadcastForwarder::OnFileCanReadWithoutBlocking,
base::Unretained(this)),
br_addr, br_broadaddr, br_netmask);
br_sockets_.emplace(br_ifname, std::move(br_socket));
// Broadcast forwarder is not started yet.
if (dev_socket_ == nullptr) {
base::ScopedFD dev_fd(BindRaw(dev_ifname_));
if (!dev_fd.is_valid()) {
LOG(WARNING) << "Could not bind socket on " << dev_ifname_;
br_sockets_.clear();
return false;
}
Ioctl(dev_fd.get(), dev_ifname_, SIOCGIFADDR, &ifr);
uint32_t dev_addr = GetIfreqAddr(ifr);
Ioctl(dev_fd.get(), dev_ifname_, SIOCGIFBRDADDR, &ifr);
uint32_t dev_broadaddr = GetIfreqBroadaddr(ifr);
dev_socket_.reset(new Socket(
std::move(dev_fd),
base::BindRepeating(&BroadcastForwarder::OnFileCanReadWithoutBlocking,
base::Unretained(this)),
dev_addr, dev_broadaddr));
}
return true;
}
// Stops forwarding for the guest interface |br_ifname| by dropping its entry
// from |br_sockets_|. Logs a warning when no such entry exists. |dev_socket_|
// is left untouched.
void BroadcastForwarder::RemoveGuest(const std::string& br_ifname) {
  if (br_sockets_.erase(br_ifname) == 0) {
    LOG(WARNING) << "Forwarding is not started between " << dev_ifname_
                 << " and " << br_ifname;
  }
}
void BroadcastForwarder::OnFileCanReadWithoutBlocking(int fd) {
alignas(4) uint8_t buffer[kBufSize];
uint8_t* data = buffer + sizeof(iphdr) + sizeof(udphdr);
sockaddr_ll dst_addr;
struct iovec iov = {
.iov_base = buffer,
.iov_len = kBufSize,
};
msghdr hdr = {
.msg_name = &dst_addr,
.msg_namelen = sizeof(dst_addr),
.msg_iov = &iov,
.msg_iovlen = 1,
.msg_control = nullptr,
.msg_controllen = 0,
.msg_flags = 0,
};
ssize_t msg_len = recvmsg(fd, &hdr, 0);
if (msg_len < 0) {
// Ignore ENETDOWN: this can happen if the interface is not yet configured.
if (errno != ENETDOWN) {
PLOG(WARNING) << "recvmsg() failed";
}
return;
}
// These headers are taken directly from the buffer and is 4 bytes aligned.
struct iphdr* ip_hdr = (struct iphdr*)(buffer);
struct udphdr* udp_hdr = (struct udphdr*)(buffer + sizeof(iphdr));
// Drop fragmented packets.
if ((ntohs(ip_hdr->frag_off) & (kIpFragOffsetMask | IP_MF)) != 0)
return;
// Store the length of the message without its headers.
ssize_t len = ntohs(udp_hdr->len) - sizeof(udphdr);
// Validate UDP length.
if ((len + sizeof(udphdr) + sizeof(iphdr) > msg_len) || (len < 0))
return;
struct sockaddr_in fromaddr = {0};
fromaddr.sin_family = AF_INET;
fromaddr.sin_port = udp_hdr->uh_sport;
fromaddr.sin_addr.s_addr = ip_hdr->saddr;
struct sockaddr_in dst = {0};
dst.sin_family = AF_INET;
dst.sin_port = udp_hdr->uh_dport;
dst.sin_addr.s_addr = ip_hdr->daddr;
// Forward ingress traffic to guests.
if (fd == dev_socket_->fd.get()) {
// Prevent looped back broadcast packets to be forwarded.
if (fromaddr.sin_addr.s_addr == dev_socket_->addr)
return;
SendToGuests(buffer, len, dst);
return;
}
for (auto const& socket : br_sockets_) {
if (fd != socket.second->fd.get())
continue;
// Prevent looped back broadcast packets to be forwarded.
if (fromaddr.sin_addr.s_addr == socket.second->addr)
return;
// We are spoofing packets source IP to be the actual sender source IP.
// Prevent looped back broadcast packets by not forwarding anything from
// outside the interface netmask.
if ((fromaddr.sin_addr.s_addr & socket.second->netmask) !=
(socket.second->addr & socket.second->netmask))
return;
// Forward egress traffic from one guest to outside network.
SendToNetwork(ntohs(fromaddr.sin_port), data, len, dst);
}
}
bool BroadcastForwarder::SendToNetwork(uint16_t src_port,
const void* data,
ssize_t len,
const struct sockaddr_in& dst) {
base::ScopedFD temp_fd(Bind(dev_ifname_, src_port));
if (!temp_fd.is_valid()) {
LOG(WARNING) << "Could not bind socket on " << dev_ifname_ << " for port "
<< src_port;
return false;
}
struct sockaddr_in dev_dst = {0};
memcpy(&dev_dst, &dst, sizeof(sockaddr_in));
if (dev_dst.sin_addr.s_addr != kBcastAddr)
dev_dst.sin_addr.s_addr = dev_socket_->broadaddr;
if (sendto(temp_fd.get(), data, len, 0,
reinterpret_cast<const struct sockaddr*>(&dev_dst),
sizeof(struct sockaddr_in)) < 0) {
// Ignore ENETDOWN: this can happen if the interface is not yet configured.
if (errno != ENETDOWN) {
PLOG(WARNING) << "sendto() failed";
}
return false;
}
return true;
}
bool BroadcastForwarder::SendToGuests(const void* ip_pkt,
ssize_t len,
const struct sockaddr_in& dst) {
bool success = true;
base::ScopedFD raw(socket(AF_INET, SOCK_RAW | SOCK_CLOEXEC, IPPROTO_UDP));
if (!raw.is_valid()) {
PLOG(ERROR) << "socket() failed for raw socket";
return false;
}
int on = 1;
if (setsockopt(raw.get(), IPPROTO_IP, IP_HDRINCL, &on, sizeof(on)) < 0) {
PLOG(ERROR) << "setsockopt(IP_HDRINCL) failed";
return false;
}
if (setsockopt(raw.get(), SOL_SOCKET, SO_BROADCAST, &on, sizeof(on)) < 0) {
PLOG(ERROR) << "setsockopt(SO_BROADCAST) failed";
return false;
}
// Copy IP packet received by the lan interface and only change its
// destination address.
alignas(4) uint8_t buffer[kBufSize];
memset(buffer, 0, kBufSize);
memcpy(buffer, reinterpret_cast<const uint8_t*>(ip_pkt),
sizeof(iphdr) + sizeof(udphdr) + len);
// These headers are taken directly from the buffer and is 4 bytes aligned.
struct iphdr* ip_hdr = (struct iphdr*)buffer;
struct udphdr* udp_hdr = (struct udphdr*)(buffer + sizeof(struct iphdr));
ip_hdr->check = 0;
udp_hdr->check = 0;
struct sockaddr_in br_dst = {0};
memcpy(&br_dst, &dst, sizeof(struct sockaddr_in));
for (auto const& socket : br_sockets_) {
// Set destination address.
if (br_dst.sin_addr.s_addr != kBcastAddr) {
br_dst.sin_addr.s_addr = socket.second->broadaddr;
ip_hdr->daddr = socket.second->broadaddr;
ip_hdr->check = Ipv4Checksum(ip_hdr);
}
udp_hdr->check = Udpv4Checksum(ip_hdr, udp_hdr);
struct ifreq ifr;
memset(&ifr, 0, sizeof(ifr));
strncpy(ifr.ifr_name, socket.first.c_str(), IFNAMSIZ);
if (setsockopt(raw.get(), SOL_SOCKET, SO_BINDTODEVICE, &ifr, sizeof(ifr))) {
PLOG(ERROR) << "setsockopt(SOL_SOCKET) failed for broadcast forwarder on "
<< socket.first;
continue;
}
// Use already created broadcast fd.
if (sendto(raw.get(), buffer,
sizeof(struct iphdr) + sizeof(struct udphdr) + len, 0,
reinterpret_cast<const struct sockaddr*>(&br_dst),
sizeof(struct sockaddr_in)) < 0) {
PLOG(WARNING) << "sendto failed";
success = false;
}
}
return success;
}
} // namespace patchpanel