blob: b3ab82abca439cb2b6d7cd9933000ce5c8233cf6 [file] [log] [blame] [edit]
// Copyright 2023 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef VM_TOOLS_CONCIERGE_MM_SERVER_H_
#define VM_TOOLS_CONCIERGE_MM_SERVER_H_
#include <memory>
#include <string>
#include <vector>
#include <base/containers/flat_map.h>
#include <base/containers/flat_set.h>
#include <base/functional/callback.h>
#include <vm_memory_management/vm_memory_management.pb.h>
#include "vm_tools/concierge/mm/vm_socket.h"
using vm_tools::vm_memory_management::ConnectionType;
using vm_tools::vm_memory_management::PacketType;
using vm_tools::vm_memory_management::VmMemoryManagementPacket;
namespace vm_tools::concierge::mm {
struct Client {
// The CID of the client. VMADDR_CID_LOCAL for host clients, and a valid VM
// CID for VM clients. Initialized to -2 which is an invalid CID.
int cid = -2;
// The connection id of the client. This ID is unique in the scope of this
// server and assigned when a client connects. It is necessary because a given
// context can have more than one client. Initialized to -1 which is an
// invalid connection_id.
int connection_id = -1;
// The type of the connection.
ConnectionType type = ConnectionType::CONNECTION_TYPE_UNSPECIFIED;
};
// Represents a connection to the server.
struct Connection {
// The specific client that is connected.
Client client{};
// The socket open to the connection.
std::unique_ptr<VmSocket> socket{};
};
// The Server interface is a server that listens on the
// supplied socket for connections from Clients. The
// connection handshakes and packet transport are all handled by the base class.
// Implementations can implement handling logic for specific message types.
class Server {
public:
using SocketFactory =
base::RepeatingCallback<std::unique_ptr<VmSocket>(base::ScopedFD)>;
explicit Server(int port, SocketFactory socket_factory);
virtual ~Server() = default;
// Starts listening for connections on the specified port.
bool StartListening();
// Registers a VM with the server. Only connections from registered VMs will
// be accepted.
void RegisterVm(int vm_cid);
// Perform cleanup operations for a VM when it is shutting down.
void RemoveVm(int vm_cid);
// START: Event callbacks.
// Sets the callback to be run when a new client connects to the server.
using ClientConnectionNotification = base::RepeatingCallback<void(Client)>;
void SetClientConnectionNotification(ClientConnectionNotification callback);
// Sets the callback to be run when a new client disconnects from the
// server.
using ClientDisconnectedNotification = base::RepeatingCallback<void(Client)>;
void SetClientDisconnectedNotification(
ClientDisconnectedNotification callback);
// END: Event Callbacks.
protected:
// Creates socket instances for use by the server.
static std::unique_ptr<VmSocket> SocketFactoryImpl(base::ScopedFD fd);
// Sends a packet with the specified type and no payload to the specified
// connection.
static bool SendNoPayloadPacket(const Connection& connection,
PacketType type);
// Gets a connection matching the specified parameters, or nullptr if none
// exists.
const Connection* GetConnection(int cid, ConnectionType type);
// Removes the specified connection from the server.
void RemoveConnection(int connection_id);
// Returns the ClientConnectionNotification for this server.
const ClientConnectionNotification& GetClientConnectionCallback();
// Returns the ClientDisconnectedNotification for this server.
const ClientDisconnectedNotification& GetClientDisconnectedCallback();
// Performs the necessary actions for the received packet.
virtual void HandlePacket(
const Connection& connection,
const VmMemoryManagementPacket& received_packet) = 0;
// Ensure calls are made on the correct sequence.
SEQUENCE_CHECKER(sequence_checker_);
private:
// Handles accepting a new connection.
void HandleAccept();
// Handles reading from a client connection.
void HandleConnectionReadable(int connection_id);
// Handles performing the initial connection handshake.
void HandleConnectionHandshake(Connection& connection,
const VmMemoryManagementPacket& packet);
// The port on which this server listens.
const int port_;
// Creates socket instances.
const SocketFactory socket_factory_;
// The socket used to accept connections.
std::unique_ptr<VmSocket> server_socket_
GUARDED_BY_CONTEXT(sequence_checker_){};
// Contains the VM CIDs that the server is allowed to accept a connection
// from.
base::flat_set<int> allowed_vms_ GUARDED_BY_CONTEXT(sequence_checker_){};
// The maximum number of simultaneous connections that are allowed on one
// server at a given time.
static constexpr size_t kMaxConnections = 8;
base::flat_map<int /* connection_id */, Connection> connections_
GUARDED_BY_CONTEXT(sequence_checker_){};
// Event Callbacks.
ClientConnectionNotification client_connection_callback_
GUARDED_BY_CONTEXT(sequence_checker_) = base::DoNothing();
ClientDisconnectedNotification client_disconnected_callback_
GUARDED_BY_CONTEXT(sequence_checker_) = base::DoNothing();
};
} // namespace vm_tools::concierge::mm
#endif // VM_TOOLS_CONCIERGE_MM_SERVER_H_