blob: b69e5d050c1085ce6bcd7ab5f70d49ce5da0824a [file] [log] [blame]
// Copyright 2017 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <fcntl.h>
#include <poll.h>
#include <signal.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/un.h>
#include <sys/wait.h>
#include <syslog.h>
#include <unistd.h>
#include <linux/virtwl.h>
// This limits the number of child processes, which limits the number of
// concurrent connections to the proxy. This roughly corresponds to the number
// of concurrent wayland applications that can be running per machine.
#define MAX_CHILD_COUNT 128
static int handle_server_in(int server_fd, int client_fd) {
uint8_t ioctl_buf[4096];
struct virtwl_ioctl_recv* ioctl_recv = (struct virtwl_ioctl_recv*)ioctl_buf;
void* recv_data = ioctl_buf + sizeof(struct virtwl_ioctl_recv);
size_t max_recv_size = sizeof(ioctl_buf) - sizeof(struct virtwl_ioctl_recv);
char fd_buf[CMSG_LEN(sizeof(int) * VIRTWL_SEND_MAX_ALLOCS)];
ioctl_recv->len = max_recv_size;
int ret = ioctl(server_fd, VIRTWL_IOCTL_RECV, ioctl_recv);
if (ret) {
syslog(LOG_USER | LOG_DEBUG, "wayland server socket has hungup: %m");
return -1;
}
struct iovec buffer_iov;
buffer_iov.iov_base = recv_data;
buffer_iov.iov_len = ioctl_recv->len;
struct msghdr msg = {0};
msg.msg_iov = &buffer_iov;
msg.msg_iovlen = 1;
msg.msg_control = fd_buf;
// Simply counts how manye FDs the kernel gave us.
int fd_count;
for (fd_count = 0; fd_count < VIRTWL_SEND_MAX_ALLOCS; fd_count++) {
if (ioctl_recv->fds[fd_count] < 0)
break;
}
if (fd_count > 0) {
// Need to set msg_controllen so CMSG_FIRSTHDR will return the first
// cmsghdr. We copy every fd we just received from the ioctl into this
// cmsghdr.
msg.msg_controllen = sizeof(fd_buf);
struct cmsghdr* cmsg = CMSG_FIRSTHDR(&msg);
cmsg->cmsg_level = SOL_SOCKET;
cmsg->cmsg_type = SCM_RIGHTS;
cmsg->cmsg_len = CMSG_LEN(fd_count * sizeof(int));
memcpy(CMSG_DATA(cmsg), ioctl_recv->fds, fd_count * sizeof(int));
msg.msg_controllen = cmsg->cmsg_len;
}
ssize_t write_size = sendmsg(client_fd, &msg, MSG_NOSIGNAL);
int i;
for (i = 0; i < fd_count; i++)
close(ioctl_recv->fds[i]);
if (write_size != ioctl_recv->len) {
syslog(LOG_USER | LOG_ERR, "failed sendmsg to client: %m");
return -1;
}
return 0;
}
static int handle_client_in(int server_fd, int client_fd) {
uint8_t ioctl_buf[4096];
struct virtwl_ioctl_send* ioctl_send = (struct virtwl_ioctl_send*)ioctl_buf;
void* send_data = ioctl_buf + sizeof(struct virtwl_ioctl_send);
size_t max_send_size = sizeof(ioctl_buf) - sizeof(struct virtwl_ioctl_send);
char fd_buf[CMSG_LEN(sizeof(int) * VIRTWL_SEND_MAX_ALLOCS)];
struct iovec buffer_iov;
buffer_iov.iov_base = send_data;
buffer_iov.iov_len = max_send_size;
struct msghdr msg = {0};
msg.msg_iov = &buffer_iov;
msg.msg_iovlen = 1;
msg.msg_control = fd_buf;
msg.msg_controllen = sizeof(fd_buf);
ssize_t read_size = recvmsg(client_fd, &msg, 0);
if (read_size == 0) {
syslog(LOG_USER | LOG_DEBUG, "client has hungup");
return -1;
}
if (read_size < 0) {
syslog(LOG_USER | LOG_ERR, "failed recvmsg from client: %m");
return -1;
}
for (int fd_idx = 0; fd_idx < VIRTWL_SEND_MAX_ALLOCS; fd_idx++)
ioctl_send->fds[fd_idx] = -1;
// If there were any FDs recv'd by recvmsg, there will be some data in the
// msg_control buffer. To get the FDs out we iterate all cmsghdr's within and
// unpack the FDs if the cmsghdr type is SCM_RIGHTS.
struct cmsghdr* cmsg = msg.msg_controllen != 0 ? CMSG_FIRSTHDR(&msg) : NULL;
for (int fd_idx = 0; cmsg; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS)
continue;
size_t cmsg_fd_count = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
// fd_idx will never exceed VIRTWL_SEND_MAX_ALLOCS because the
// control message buffer only allocates enough space for that many FDs.
memcpy(&ioctl_send->fds[fd_idx],
CMSG_DATA(cmsg),
cmsg_fd_count * sizeof(int));
fd_idx += cmsg_fd_count;
}
// The FDs and data were extracted from the recvmsg call into the ioctl_send
// structure which we now pass along to the kernel.
ioctl_send->len = read_size;
int ret = ioctl(server_fd, VIRTWL_IOCTL_SEND, ioctl_send);
if (ret)
syslog(LOG_USER | LOG_ERR, "failed to IOCTL_SEND to server: %m");
for (int fd_idx = 0; fd_idx < VIRTWL_SEND_MAX_ALLOCS; fd_idx++) {
int fd = ioctl_send->fds[fd_idx];
if (fd >= 0)
close(fd);
}
if (ret)
return -1;
return 0;
}
static int proxy_main(int wl_fd, int client_socket) {
int (*handlers[2])(int, int) = {handle_client_in, handle_server_in};
struct pollfd fds[2];
fds[0].fd = client_socket;
fds[0].events = POLLIN;
fds[1].fd = wl_fd;
fds[1].events = POLLIN;
int ret = 0;
while ((ret = poll(fds, 2, -1)) != -1) {
for (int i = 0; i < 2; i++) {
if ((fds[i].revents & POLLIN) == 0)
continue;
ret = handlers[i](wl_fd, client_socket);
if (ret)
goto end;
}
}
end:
close(wl_fd);
close(client_socket);
return ret;
}
static void empty_handler(int signum) {
}
int main(int argc, char** argv) {
struct sockaddr_un addr;
addr.sun_family = AF_UNIX;
snprintf(addr.sun_path,
sizeof(addr.sun_path) - 1,
"%s/wayland-0",
getenv("XDG_RUNTIME_DIR"));
size_t len = strlen(addr.sun_path) + sizeof(addr.sun_family);
unlink(addr.sun_path);
int server_socket = socket(AF_UNIX, SOCK_STREAM, 0);
if (server_socket < 0) {
syslog(LOG_USER | LOG_ERR, "failed to create listening socket: %m");
return 1;
}
if (bind(server_socket, (struct sockaddr*)&addr, len) != 0) {
syslog(LOG_USER | LOG_ERR, "failed to bind listening socket: %m");
return 1;
}
if (listen(server_socket, 8) != 0) {
syslog(LOG_USER | LOG_ERR, "failed to listen to socket: %m");
return 1;
}
struct sigaction child_action;
memset(&child_action, 0, sizeof(child_action));
sigemptyset(&child_action.sa_mask);
child_action.sa_handler = empty_handler;
sigaction(SIGCHLD, &child_action, NULL);
int child_count = 0;
for (;;) {
// The child_count is used to limit the number of child processes that one
// proxy will handle before it will start dropping new connections. Because
// accept is the only blocking call made in the main loop, we can reap
// children and update the child_count just before to get an accurate count.
// If children die while blocked in accept, the empty signal handler will
// ensure the accept gets interrupted, which we check for in order to
// restart the main loop. There is an intrinsic race condition in which a
// child dies just after accept returns a new connection in which case
// child_count would be inaccurate and lead to an inappropriate drop of the
// new connection. However, this race condition is unavoidable and will not
// lead to a permanent DoS as the child_count will become accurate on the
// next iteration.
while (waitpid(-1, NULL, WNOHANG) > 0) {
if (child_count > 0) {
child_count--;
} else {
syslog(LOG_USER | LOG_WARNING, "reaped more children than spawned");
}
}
struct sockaddr_un remote_addr;
socklen_t remote_addr_len = sizeof(struct sockaddr_un);
int client_socket =
accept(server_socket, (struct sockaddr*)&remote_addr, &remote_addr_len);
if (client_socket == -1) {
// An EINTR probably means that the SIGCHLD handler was called and
// children need to be reaped.
if (errno == EINTR)
continue;
syslog(LOG_USER | LOG_ERR, "failed to accept incoming socket: %m");
return 1;
}
if (child_count >= MAX_CHILD_COUNT) {
syslog(LOG_USER | LOG_WARNING,
"dropping excessive number of client connections");
close(client_socket);
continue;
}
static char log_ident[32] = "virtwl";
struct sockaddr_storage peer_name;
socklen_t peer_name_len = sizeof(struct sockaddr_storage);
int ret = getsockname(client_socket, (struct sockaddr*)&peer_name,
&peer_name_len);
if (ret == 0) {
struct sockaddr_un* peer_name_un = (struct sockaddr_un*)&peer_name;
syslog(LOG_USER | LOG_INFO, "client connected: %s",
peer_name_un->sun_path);
snprintf(log_ident, sizeof(log_ident) - 1, "virtwl-%s",
peer_name_un->sun_path);
} else {
syslog(LOG_USER | LOG_INFO, "client connected: error getting name: %m");
}
// We use fork here so that each client connection is isolated from crashes
// in the other, and so that each gets it's own set of FDs.
ret = fork();
if (ret == 0) { /* child */
openlog(log_ident, LOG_PERROR | LOG_PID, LOG_USER);
close(server_socket);
int wl_fd = open("/dev/wl0", O_RDWR);
if (wl_fd < 0) {
syslog(LOG_USER | LOG_ERR, "failed to open wl0: %m");
return 1;
}
struct virtwl_ioctl_new new_ctx = {
.type = VIRTWL_IOCTL_NEW_CTX,
.fd = -1,
.flags = 0,
.size = 0,
};
ret = ioctl(wl_fd, VIRTWL_IOCTL_NEW, &new_ctx);
close(wl_fd);
if (ret) {
syslog(LOG_USER | LOG_ERR, "failed to create new wayland context: %m");
return 1;
}
return proxy_main(new_ctx.fd, client_socket);
} else if (ret == -1) {
syslog(LOG_USER | LOG_ERR, "failed to fork client handler: %m");
}
child_count++;
close(client_socket);
}
return 0;
}