blob: f99d3dd5db7b81b4349c0f1140bad78da3ab8a6f [file] [log] [blame] [edit]
// Copyright 2023 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "vm_tools/concierge/mm/server.h"
#include <sys/socket.h>
#include <linux/vm_sockets.h> // Needs to come after sys/socket.h
#include <memory>
#include <utility>
#include <base/logging.h>
#include "vm_tools/concierge/mm/mm_service.h"
#include "vm_tools/concierge/mm/vm_socket.h"
using vm_tools::vm_memory_management::ConnectionHandshake;
namespace vm_tools::concierge::mm {
Server::Server(int port, SocketFactory socket_factory)
: port_(port), socket_factory_(socket_factory) {}
bool Server::StartListening() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (server_socket_) {
LOG(ERROR) << "Server is already listening.";
return false;
}
// Host is always allowed to connect.
RegisterVm(VMADDR_CID_LOCAL);
// Create a new server socket.
server_socket_ = socket_factory_.Run({});
// The necessary backlog size depends on how many VMs are managed by the
// service.
if (!server_socket_->Listen(port_, MmService::ManagedVms().size())) {
PLOG(ERROR) << "Failed to start listening for a connection on "
"Server VSOCK";
return false;
}
if (!server_socket_->OnReadable(
base::BindRepeating(&Server::HandleAccept, base::Unretained(this)))) {
PLOG(ERROR) << "Failed to watch for connections on VSOCK.";
return false;
}
LOG(INFO) << "Waiting for Server socket connections on "
"VSOCK port: "
<< port_;
return true;
}
void Server::RegisterVm(int vm_cid) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
allowed_vms_.emplace(vm_cid);
}
void Server::RemoveVm(int vm_cid) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
std::vector<int> connections_ids_to_erase{};
for (const auto& connection : connections_) {
if (vm_cid == connection.second.client.cid) {
client_disconnected_callback_.Run(connection.second.client);
connections_ids_to_erase.emplace_back(connection.first);
}
}
for (int connection_id : connections_ids_to_erase) {
connections_.erase(connection_id);
}
allowed_vms_.erase(vm_cid);
}
void Server::SetClientConnectionNotification(
ClientConnectionNotification callback) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
client_connection_callback_ = callback;
}
void Server::SetClientDisconnectedNotification(
ClientDisconnectedNotification callback) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
client_disconnected_callback_ = callback;
}
// static
std::unique_ptr<VmSocket> Server::SocketFactoryImpl(base::ScopedFD fd) {
return std::make_unique<VmSocket>(std::move(fd));
}
bool Server::SendNoPayloadPacket(const Connection& connection,
PacketType type) {
VmMemoryManagementPacket packet;
packet.set_type(type);
return connection.socket->WritePacket(packet);
}
const Connection* Server::GetConnection(int cid,
ConnectionType connection_type) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
for (const auto& connection : connections_) {
if (connection.second.client.cid == cid &&
connection.second.client.type == connection_type) {
return &connections_[connection.second.client.connection_id];
}
}
return nullptr;
}
void Server::RemoveConnection(int connection_id) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (!connections_.contains(connection_id)) {
return;
}
client_disconnected_callback_.Run(connections_[connection_id].client);
connections_.erase(connection_id);
}
const Server::ClientConnectionNotification&
Server::GetClientConnectionCallback() {
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wthread-safety-reference-return"
// TODO(b/331634345): Fix thread safety of reference return.
return client_connection_callback_;
#pragma clang diagnostic pop
}
const Server::ClientDisconnectedNotification&
Server::GetClientDisconnectedCallback() {
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wthread-safety-reference-return"
// TODO(b/331634345): Fix thread safety of reference return.
return client_disconnected_callback_;
#pragma clang diagnostic pop
}
void Server::HandleAccept() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (connections_.size() >= kMaxConnections) {
LOG(ERROR) << "Max connections reached. Ignoring new connection request.";
return;
}
Connection new_connection;
int connected_cid = 0;
base::ScopedFD accepted_fd = server_socket_->Accept(connected_cid);
// Use the accepted_fd as the ID for this connection. ids are not used as fds,
// but the fd is a unique identifier within the scope of this server and can
// be used as the id value.
int new_connection_id = accepted_fd.get();
new_connection.socket = socket_factory_.Run(std::move(accepted_fd));
if (!new_connection.socket->IsValid()) {
PLOG(ERROR) << "Server failed to accept new connection";
return;
}
if (!allowed_vms_.contains(connected_cid)) {
LOG(ERROR) << "Server rejecting connection from "
"un-registered VM: "
<< connected_cid;
return;
}
if (!new_connection.socket->OnReadable(
base::BindRepeating(&Server::HandleConnectionReadable,
base::Unretained(this), new_connection_id))) {
PLOG(ERROR) << "Failed to start watching reads from new client";
return;
}
new_connection.client.cid = connected_cid;
new_connection.client.connection_id = new_connection_id;
connections_.emplace(new_connection_id, std::move(new_connection));
}
void Server::HandleConnectionReadable(int connection_id) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (!connections_.contains(connection_id)) {
LOG(ERROR) << "Received request from non-existant client";
return;
}
Connection& connection = connections_[connection_id];
VmMemoryManagementPacket received_packet;
if (!connection.socket->ReadPacket(received_packet)) {
LOG(ERROR) << "Failed to fully read packet from client: "
<< connection.client.cid;
RemoveConnection(connection_id);
return;
}
switch (received_packet.type()) {
case PacketType::PACKET_TYPE_HANDSHAKE:
return HandleConnectionHandshake(connection, received_packet);
default:
return HandlePacket(connection, received_packet);
}
}
void Server::HandleConnectionHandshake(Connection& connection,
const VmMemoryManagementPacket& packet) {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
if (!packet.has_handshake() ||
packet.handshake().type() ==
ConnectionType::CONNECTION_TYPE_UNSPECIFIED) {
LOG(ERROR) << "Invalid connection handshake received";
if (!SendNoPayloadPacket(connection,
PacketType::PACKET_TYPE_CONNECTION_NACK)) {
RemoveConnection(connection.client.connection_id);
}
return;
}
const ConnectionHandshake& handshake = packet.handshake();
connection.client.type = handshake.type();
if (!SendNoPayloadPacket(connection,
PacketType::PACKET_TYPE_CONNECTION_ACK)) {
LOG(ERROR) << "Failed to send CONNECTION_ACK to client: "
<< connection.client.cid;
RemoveConnection(connection.client.connection_id);
return;
}
client_connection_callback_.Run(connection.client);
LOG(INFO) << "Server accepted new connection. CID: " << connection.client.cid
<< " Type: " << ConnectionType_Name(connection.client.type);
}
} // namespace vm_tools::concierge::mm