blob: feda8973d67f7ec3cef14e9197f177deec70c0ab [file] [log] [blame]
// Copyright 2017 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <poll.h>
#include <pthread.h>
#include <signal.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/un.h>
#include <sys/wait.h>
#include <syslog.h>
#include <unistd.h>
#include <linux/virtwl.h>
// This limits the number of child processes, which limits the number of
// concurrent connections to the proxy. This roughly corresponds to the number
// of concurrent wayland applications that can be running per machine.
#define MAX_CHILD_COUNT 128
// No matter what kind of virtwl fd we are given, a zero-sized ioctl send should
// succeed. This property is used to determine if the given fd is a virtwl fd.
static bool is_virtwl_fd(int fd) {
struct virtwl_ioctl_txn ioctl_send;
for (int fd_idx = 0; fd_idx < VIRTWL_SEND_MAX_ALLOCS; fd_idx++)
ioctl_send.fds[fd_idx] = -1;
ioctl_send.len = 0;
return ioctl(fd, VIRTWL_IOCTL_SEND, &ioctl_send) == 0;
}
// Body of a pipe-proxy thread: copies bytes from one fd to another until the
// input reaches end-of-file or either side reports an error.
//
// |args| is a heap-allocated int[2] holding {in_pipe, out_pipe}. This routine
// takes ownership of the array (freed immediately) and of both fds (closed on
// exit). Always returns NULL.
static void* pipe_proxy_routine(void* args) {
  int* fds = (int*)args;
  int in_pipe = fds[0];
  int out_pipe = fds[1];
  free(fds);
  uint8_t buf[PIPE_BUF];
  for (;;) {
    ssize_t read_count = read(in_pipe, buf, sizeof(buf));
    if (read_count == -1) {
      // A signal arriving mid-read is not a failure of the pipe; retry.
      if (errno == EINTR)
        continue;
      syslog(LOG_USER | LOG_ERR, "error reading from input pipe: %m");
      break;
    }
    // Check for hangup.
    if (read_count == 0)
      break;
    // write() may accept fewer bytes than requested (or be interrupted), so
    // keep writing until the whole chunk is delivered instead of dropping it.
    size_t written = 0;
    while (written < (size_t)read_count) {
      ssize_t ret = write(out_pipe, buf + written, (size_t)read_count - written);
      if (ret == -1) {
        if (errno == EINTR)
          continue;
        syslog(LOG_USER | LOG_ERR, "error writing to output pipe: %m");
        goto done;
      }
      // A zero-byte write for a non-empty request means no progress is
      // possible; bail out rather than spin.
      if (ret == 0) {
        syslog(LOG_USER | LOG_ERR, "incomplete write to output pipe %d",
               out_pipe);
        goto done;
      }
      written += (size_t)ret;
    }
  }
done:
  close(in_pipe);
  close(out_pipe);
  return NULL;
}
// Starts a detached thread running pipe_proxy_routine() that copies data from
// |in_fd| to |out_fd|.
//
// Returns 0 on success, in which case the thread owns both fds and closes them
// when it finishes. On failure a positive errno-style code is returned and the
// caller still owns both fds.
static int launch_pipe_proxy(int in_fd, int out_fd) {
  int* fds = calloc(2, sizeof(*fds));
  if (!fds)
    return ENOMEM;
  fds[0] = in_fd;
  fds[1] = out_fd;
  pthread_attr_t attr;
  int ret = pthread_attr_init(&attr);
  if (ret == 0) {
    ret = pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_DETACHED);
    if (ret == 0) {
      pthread_t thread;
      ret = pthread_create(&thread, &attr, pipe_proxy_routine, fds);
    }
    // The attribute object is no longer needed once the thread exists.
    pthread_attr_destroy(&attr);
  }
  if (ret) {
    free(fds);
    // pthread functions return their error code rather than setting errno;
    // load it into errno so %m reports the real cause.
    errno = ret;
    syslog(LOG_USER | LOG_ERR, "failed to create pipe proxy thread: %m");
    return ret;
  }
  return 0;
}
static int handle_server_in(int wl0_fd, int server_fd, int client_fd) {
uint8_t ioctl_buf[4096];
struct virtwl_ioctl_txn* ioctl_recv = (struct virtwl_ioctl_txn*)ioctl_buf;
void* recv_data = ioctl_buf + sizeof(struct virtwl_ioctl_txn);
size_t max_recv_size = sizeof(ioctl_buf) - sizeof(struct virtwl_ioctl_txn);
char fd_buf[CMSG_LEN(sizeof(int) * VIRTWL_SEND_MAX_ALLOCS)];
ioctl_recv->len = max_recv_size;
int ret = ioctl(server_fd, VIRTWL_IOCTL_RECV, ioctl_recv);
if (ret) {
syslog(LOG_USER | LOG_DEBUG, "wayland server socket has hungup: %m");
return -1;
}
struct iovec buffer_iov;
buffer_iov.iov_base = recv_data;
buffer_iov.iov_len = ioctl_recv->len;
struct msghdr msg = {0};
msg.msg_iov = &buffer_iov;
msg.msg_iovlen = 1;
msg.msg_control = fd_buf;
// Simply counts how manye FDs the kernel gave us.
int fd_count;
for (fd_count = 0; fd_count < VIRTWL_SEND_MAX_ALLOCS; fd_count++) {
if (ioctl_recv->fds[fd_count] < 0)
break;
}
if (fd_count > 0) {
// Need to set msg_controllen so CMSG_FIRSTHDR will return the first
// cmsghdr. We copy every fd we just received from the ioctl into this
// cmsghdr.
msg.msg_controllen = sizeof(fd_buf);
struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
cmsg->cmsg_level = SOL_SOCKET;
cmsg->cmsg_type = SCM_RIGHTS;
cmsg->cmsg_len = CMSG_LEN(fd_count * sizeof(int));
memcpy(CMSG_DATA(cmsg), ioctl_recv->fds, fd_count * sizeof(int));
msg.msg_controllen = cmsg->cmsg_len;
}
ssize_t write_size = sendmsg(client_fd, &msg, MSG_NOSIGNAL);
int i;
for (i = 0; i < fd_count; i++)
close(ioctl_recv->fds[i]);
if (write_size != ioctl_recv->len) {
syslog(LOG_USER | LOG_ERR, "failed sendmsg to client: %m");
return -1;
}
return 0;
}
static int handle_client_in(int wl0_fd, int server_fd, int client_fd) {
uint8_t ioctl_buf[4096];
struct virtwl_ioctl_txn* ioctl_send = (struct virtwl_ioctl_txn*)ioctl_buf;
void* send_data = ioctl_buf + sizeof(struct virtwl_ioctl_txn);
size_t max_send_size = sizeof(ioctl_buf) - sizeof(struct virtwl_ioctl_txn);
char fd_buf[CMSG_LEN(sizeof(int) * VIRTWL_SEND_MAX_ALLOCS)];
uint8_t retain_fds[VIRTWL_SEND_MAX_ALLOCS] = {0};
struct iovec buffer_iov;
buffer_iov.iov_base = send_data;
buffer_iov.iov_len = max_send_size;
struct msghdr msg = {0};
msg.msg_iov = &buffer_iov;
msg.msg_iovlen = 1;
msg.msg_control = fd_buf;
msg.msg_controllen = sizeof(fd_buf);
ssize_t read_size = recvmsg(client_fd, &msg, 0);
if (read_size == 0) {
syslog(LOG_USER | LOG_DEBUG, "client has hungup");
return -1;
}
if (read_size < 0) {
syslog(LOG_USER | LOG_ERR, "failed recvmsg from client: %m");
return -1;
}
for (int fd_idx = 0; fd_idx < VIRTWL_SEND_MAX_ALLOCS; fd_idx++)
ioctl_send->fds[fd_idx] = -1;
// If there were any FDs recv'd by recvmsg, there will be some data in the
// msg_control buffer. To get the FDs out we iterate all cmsghdr's within and
// unpack the FDs if the cmsghdr type is SCM_RIGHTS.
struct cmsghdr* cmsg = msg.msg_controllen != 0 ? CMSG_FIRSTHDR(&msg) : NULL;
for (int fd_idx = 0; cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS)
continue;
size_t cmsg_fd_count = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
// fd_idx will never exceed VIRTWL_SEND_MAX_ALLOCS because the
// control message buffer only allocates enough space for that many FDs.
memcpy(&ioctl_send->fds[fd_idx], CMSG_DATA(cmsg),
cmsg_fd_count * sizeof(int));
fd_idx += cmsg_fd_count;
}
for (int fd_idx = 0; fd_idx < VIRTWL_SEND_MAX_ALLOCS; fd_idx++) {
int fd = ioctl_send->fds[fd_idx];
if (fd < 0)
break;
if (is_virtwl_fd(fd))
continue;
// If the client sends us a non-virtwl FD, it's likely some kind of pipe
// that we can manually proxy in another thread.
struct virtwl_ioctl_new new_pipe = {
.type = 0,
.fd = -1,
.flags = 0,
.size = 0,
};
int flags = fcntl(fd, F_GETFL) & O_ACCMODE;
switch (flags) {
case O_RDONLY:
new_pipe.type = VIRTWL_IOCTL_NEW_PIPE_WRITE;
break;
case O_WRONLY:
// virtwl does not support read/write pipes but pipes sent from the client
// are likely intended to be written to by the remote end.
case O_RDWR:
new_pipe.type = VIRTWL_IOCTL_NEW_PIPE_READ;
break;
default:
continue;
}
int ret = ioctl(wl0_fd, VIRTWL_IOCTL_NEW, &new_pipe);
if (ret) {
syslog(LOG_USER | LOG_ERR, "failed to create virtwl pipe: %m");
return -1;
}
if (flags == O_RDONLY)
ret = launch_pipe_proxy(fd, new_pipe.fd);
else
ret = launch_pipe_proxy(new_pipe.fd, fd);
if (ret) {
close(new_pipe.fd);
return -1;
} else {
ioctl_send->fds[fd_idx] = new_pipe.fd;
retain_fds[fd_idx] = 1;
}
}
// The FDs and data were extracted from the recvmsg call into the ioctl_send
// structure which we now pass along to the kernel.
ioctl_send->len = read_size;
int ret = ioctl(server_fd, VIRTWL_IOCTL_SEND, ioctl_send);
if (ret)
syslog(LOG_USER | LOG_ERR, "failed to IOCTL_SEND to server: %m");
for (int fd_idx = 0; fd_idx < VIRTWL_SEND_MAX_ALLOCS; fd_idx++) {
int fd = ioctl_send->fds[fd_idx];
if (fd >= 0 && !retain_fds[fd_idx])
close(fd);
}
if (ret)
return -1;
return 0;
}
// Event loop for one client connection. It waits for input on the client
// socket and on the virtwl wayland context. Client input is dispatched to
// handle_client_in() and context input to handle_server_in().
//
// Runs until either side hangs up or fails, then closes |wl0_fd|, |wl_fd| and
// |client_socket|. Returns 0 when a descriptor hangs up with no pending input,
// and a nonzero value when poll() or a handler fails or an fd reports an error.
static int proxy_main(int wl0_fd, int wl_fd, int client_socket) {
  // Indexed in parallel with fds[] below.
  int (*const handlers[2])(int, int, int) = {handle_client_in,
                                             handle_server_in};
  struct pollfd fds[2] = {
      {.fd = client_socket, .events = POLLIN},
      {.fd = wl_fd, .events = POLLIN},
  };
  int ret = 0;
  for (;;) {
    if (poll(fds, 2, -1) == -1) {
      // A signal interrupting the wait is not a connection failure.
      if (errno == EINTR)
        continue;
      syslog(LOG_USER | LOG_ERR, "failed to poll: %m");
      ret = -1;
      goto end;
    }
    for (int i = 0; i < 2; i++) {
      short revents = fds[i].revents;
      if (revents & POLLIN) {
        ret = handlers[i](wl0_fd, wl_fd, client_socket);
        if (ret)
          goto end;
      } else if (revents & (POLLERR | POLLNVAL)) {
        // poll() reports these even though they were not requested; ignoring
        // them would make every subsequent poll() return immediately.
        syslog(LOG_USER | LOG_ERR, "poll reported an error on fd %d",
               fds[i].fd);
        ret = -1;
        goto end;
      } else if (revents & POLLHUP) {
        ret = 0;
        goto end;
      }
    }
  }
end:
  close(wl0_fd);
  close(wl_fd);
  close(client_socket);
  return ret;
}
// Installed for SIGCHLD by main() without SA_RESTART. It intentionally does
// nothing. Its only job is to let a delivered SIGCHLD interrupt the blocking
// accept() with EINTR, so the main loop gets a chance to reap exited children.
static void empty_handler(int signum) {
  (void)signum;
}
// Entry point of the virtwl wayland proxy.
//
// Listens on $XDG_RUNTIME_DIR/wayland-0 and forks one child per accepted
// connection, keeping at most MAX_CHILD_COUNT live children. Each child opens
// /dev/wl0, creates a new virtwl wayland context, and runs proxy_main() to
// shuttle traffic between the client and that context. Returns 1 on setup
// failure, or if accept() fails with anything other than EINTR. Otherwise the
// parent loops forever.
int main(int argc, char** argv) {
  // Handle broken pipes without signals that kill the entire process.
  signal(SIGPIPE, SIG_IGN);
  // getenv() returns NULL when the variable is unset, and passing NULL for a
  // %s conversion is undefined behavior.
  const char* runtime_dir = getenv("XDG_RUNTIME_DIR");
  if (!runtime_dir) {
    syslog(LOG_USER | LOG_ERR, "XDG_RUNTIME_DIR is not set");
    return 1;
  }
  struct sockaddr_un addr;
  memset(&addr, 0, sizeof(addr));
  addr.sun_family = AF_UNIX;
  int path_len = snprintf(addr.sun_path, sizeof(addr.sun_path) - 1,
                          "%s/wayland-0", runtime_dir);
  // A truncated path would make us unlink and bind some other file, so refuse
  // to start instead.
  if (path_len < 0 || (size_t)path_len >= sizeof(addr.sun_path) - 1) {
    syslog(LOG_USER | LOG_ERR, "socket path is too long: %s/wayland-0",
           runtime_dir);
    return 1;
  }
  socklen_t len = strlen(addr.sun_path) + sizeof(addr.sun_family);
  unlink(addr.sun_path);
  // The socket must be world-writable to be accessible by a container with
  // user namespaces.
  umask(0);
  int server_socket = socket(AF_UNIX, SOCK_STREAM | SOCK_CLOEXEC, 0);
  if (server_socket < 0) {
    syslog(LOG_USER | LOG_ERR, "failed to create listening socket: %m");
    return 1;
  }
  if (bind(server_socket, (struct sockaddr*)&addr, len) != 0) {
    syslog(LOG_USER | LOG_ERR, "failed to bind listening socket: %m");
    return 1;
  }
  if (listen(server_socket, 8) != 0) {
    syslog(LOG_USER | LOG_ERR, "failed to listen to socket: %m");
    return 1;
  }
  struct sigaction child_action;
  memset(&child_action, 0, sizeof(child_action));
  sigemptyset(&child_action.sa_mask);
  // sa_flags stays 0 (no SA_RESTART), so a SIGCHLD interrupts accept() with
  // EINTR and the loop below gets to reap the exited child.
  child_action.sa_handler = empty_handler;
  sigaction(SIGCHLD, &child_action, NULL);
  int child_count = 0;
  for (;;) {
    // The child_count is used to limit the number of child processes that one
    // proxy will handle before it will start dropping new connections. Because
    // accept is the only blocking call made in the main loop, we can reap
    // children and update the child_count just before to get an accurate count.
    // If children die while blocked in accept, the empty signal handler will
    // ensure the accept gets interrupted, which we check for in order to
    // restart the main loop. There is an intrinsic race condition in which a
    // child dies just after accept returns a new connection in which case
    // child_count would be inaccurate and lead to an inappropriate drop of the
    // new connection. However, this race condition is unavoidable and will not
    // lead to a permanent DoS as the child_count will become accurate on the
    // next iteration.
    while (waitpid(-1, NULL, WNOHANG) > 0) {
      if (child_count > 0) {
        child_count--;
      } else {
        syslog(LOG_USER | LOG_WARNING, "reaped more children than spawned");
      }
    }
    struct sockaddr_un remote_addr;
    socklen_t remote_addr_len = sizeof(struct sockaddr_un);
    int client_socket =
        accept(server_socket, (struct sockaddr*)&remote_addr, &remote_addr_len);
    if (client_socket == -1) {
      // An EINTR probably means that the SIGCHLD handler was called and
      // children need to be reaped.
      if (errno == EINTR)
        continue;
      syslog(LOG_USER | LOG_ERR, "failed to accept incoming socket: %m");
      return 1;
    }
    if (child_count >= MAX_CHILD_COUNT) {
      syslog(LOG_USER | LOG_WARNING,
             "dropping excessive number of client connections");
      close(client_socket);
      continue;
    }
    static char log_ident[32] = "virtwl";
    struct sockaddr_storage peer_name;
    socklen_t peer_name_len = sizeof(struct sockaddr_storage);
    // Note: getsockname() reports the local address of the accepted socket
    // (the listening path), not the address of the connecting client.
    int ret = getsockname(client_socket, (struct sockaddr*)&peer_name,
                          &peer_name_len);
    if (ret == 0) {
      struct sockaddr_un* peer_name_un = (struct sockaddr_un*)&peer_name;
      syslog(LOG_USER | LOG_INFO, "client connected: %s",
             peer_name_un->sun_path);
      snprintf(log_ident, sizeof(log_ident) - 1, "virtwl-%s",
               peer_name_un->sun_path);
    } else {
      syslog(LOG_USER | LOG_INFO, "client connected: error getting name: %m");
    }
    // We use fork here so that each client connection is isolated from crashes
    // in the other, and so that each gets its own set of FDs.
    pid_t pid = fork();
    if (pid == 0) { /* child */
      openlog(log_ident, LOG_PERROR | LOG_PID, LOG_USER);
      close(server_socket);
      int wl_fd = open("/dev/wl0", O_RDWR | O_CLOEXEC);
      if (wl_fd < 0) {
        syslog(LOG_USER | LOG_ERR, "failed to open wl0: %m");
        return 1;
      }
      struct virtwl_ioctl_new new_ctx = {
          .type = VIRTWL_IOCTL_NEW_CTX,
          .fd = -1,
          .flags = 0,
          .size = 0,
      };
      ret = ioctl(wl_fd, VIRTWL_IOCTL_NEW, &new_ctx);
      if (ret) {
        syslog(LOG_USER | LOG_ERR, "failed to create new wayland context: %m");
        return 1;
      }
      return proxy_main(wl_fd, new_ctx.fd, client_socket);
    }
    if (pid == -1) {
      syslog(LOG_USER | LOG_ERR, "failed to fork client handler: %m");
    } else {
      // Count only children that exist. A failed fork never produces a child
      // to reap, so counting it would inflate child_count permanently and
      // eventually cause every new connection to be dropped.
      child_count++;
    }
    close(client_socket);
  }
}