blob: 5676eea45ce00ea1c10c5045294eb5e7bf95a8f7 [file] [log] [blame]
// Copyright 2017 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "smbprovider/mount_manager.h"
#include <algorithm>
#include <base/strings/string_util.h>
#include <base/time/tick_clock.h>
#include "smbprovider/smb_credential.h"
#include "smbprovider/smbprovider_helper.h"
namespace smbprovider {
namespace {
// Returns true if a buffer of |buffer_length| bytes can hold |str| plus its
// trailing null terminator. Non-positive lengths never fit anything. The size
// comparison is done in size_t so that very long strings cannot overflow the
// int32_t arithmetic (str.size() + 1 <= len is equivalent to str.size() < len).
bool CanBufferHoldString(const std::string& str, int32_t buffer_length) {
  if (buffer_length <= 0) {
    return false;
  }
  return str.size() < static_cast<size_t>(buffer_length);
}
// Returns true if a buffer of |buffer_length| bytes can hold |password| plus a
// trailing null terminator. |password| must be non-null. Non-positive lengths
// never fit anything, and the comparison is done in size_t so that an oversized
// password cannot overflow int32_t arithmetic.
bool CanBufferHoldPassword(
    const std::unique_ptr<password_provider::Password>& password,
    int32_t buffer_length) {
  DCHECK(password);
  if (buffer_length <= 0) {
    return false;
  }
  return static_cast<size_t>(password->size()) <
         static_cast<size_t>(buffer_length);
}
// Turns |buffer| into an empty C string by writing a null terminator into its
// first byte. |buffer| must be non-null and at least one byte long.
void SetBufferEmpty(char* buffer) {
  DCHECK(buffer);
  *buffer = '\0';
}
// Copies |str| into |buffer| and appends a null terminator. |buffer| must have
// room for str.size() + 1 bytes; CanBufferHoldString() checks exactly that.
void CopyStringToBuffer(const std::string& str, char* buffer) {
  DCHECK(buffer);
  // std::string::copy() copies exactly str.size() bytes. strncpy() with a bound
  // equal to the source length never writes the terminator itself (GCC flags
  // that pattern with -Wstringop-truncation), so it is not used here.
  str.copy(buffer, str.size());
  buffer[str.size()] = '\0';
}
// Copies the raw bytes of |password| into |buffer| and appends a null
// terminator. |password| must be non-null, and |buffer| must have room for
// password->size() + 1 bytes; CanBufferHoldPassword() checks exactly that.
void CopyPasswordToBuffer(
    const std::unique_ptr<password_provider::Password>& password,
    char* buffer) {
  DCHECK(password);
  DCHECK(buffer);
  // Copy exactly size() bytes. strncpy() with a bound equal to the source
  // length never writes the terminator itself, so the terminator is added
  // explicitly below.
  const size_t length = password->size();
  std::copy_n(password->GetRaw(), length, buffer);
  buffer[length] = '\0';
}
// Verifies that |credential| fits into buffers of the given lengths, where each
// buffer must also have room for a null terminator. The password buffer length
// is only checked when |credential| actually carries a password. Logs and
// returns false if any checked buffer is too small; otherwise returns true.
bool CanInputCredential(int32_t workgroup_length,
                        int32_t username_length,
                        int32_t password_length,
                        const SmbCredential& credential) {
  const bool names_fit =
      CanBufferHoldString(credential.workgroup, workgroup_length) &&
      CanBufferHoldString(credential.username, username_length);
  if (!names_fit) {
    LOG(ERROR) << "Credential buffers are too small for input.";
    return false;
  }

  const bool has_password = static_cast<bool>(credential.password);
  const bool password_fits =
      !has_password ||
      CanBufferHoldPassword(credential.password, password_length);
  if (!password_fits) {
    LOG(ERROR) << "Password buffer is too small for input.";
    return false;
  }
  return true;
}
// Writes |credential| into the caller-provided buffers, null-terminating each
// one. If |credential| has no password, the password buffer is left holding an
// empty string. No bounds checking happens here, so callers must first confirm
// the buffers are large enough via CanInputCredential().
void PopulateCredential(const SmbCredential& credential,
                        char* workgroup_buffer,
                        char* username_buffer,
                        char* password_buffer) {
  DCHECK(workgroup_buffer);
  DCHECK(username_buffer);
  DCHECK(password_buffer);

  CopyStringToBuffer(credential.workgroup, workgroup_buffer);
  CopyStringToBuffer(credential.username, username_buffer);

  if (credential.password) {
    CopyPasswordToBuffer(credential.password, password_buffer);
    return;
  }
  SetBufferEmpty(password_buffer);
}
// Reads a password from |password_fd|. The stream must begin with the raw bytes
// of a size_t holding the password length, followed by the password itself.
// Returns null if the length prefix cannot be read or is zero. Otherwise it
// returns whatever password_provider::Password::CreateFromFileDescriptor()
// produces for the remaining bytes.
std::unique_ptr<password_provider::Password> GetPassword(
    const base::ScopedFD& password_fd) {
  size_t length = 0;
  // The length prefix is read directly into the size_t's storage.
  const bool read_length = base::ReadFromFD(
      password_fd.get(), reinterpret_cast<char*>(&length), sizeof(length));
  if (!read_length) {
    LOG(ERROR) << "Could not read password from file.";
    return nullptr;
  }

  // A zero length means the caller supplied no password at all.
  if (length == 0) {
    return nullptr;
  }

  return password_provider::Password::CreateFromFileDescriptor(
      password_fd.get(), length);
}
} // namespace
// Takes ownership of |tick_clock| and stores |samba_interface_factory|, then
// uses that factory to create the system SambaInterface. The system interface
// is not registered in |samba_interface_map_|, so it is not associated with
// any mount id.
MountManager::MountManager(std::unique_ptr<base::TickClock> tick_clock,
SambaInterfaceFactory samba_interface_factory)
: tick_clock_(std::move(tick_clock)),
samba_interface_factory_(std::move(samba_interface_factory)) {
// CreateSambaInterface() runs |samba_interface_factory_|, so it is called
// here in the body, after that member has been initialized.
system_samba_interface_ = CreateSambaInterface();
}
// All owned state is released by the members' own destructors.
MountManager::~MountManager() = default;
// Returns true if |mount_id| refers to an active mount. In debug builds this
// also verifies that the mount's root is tracked in |mounted_share_paths_|.
bool MountManager::IsAlreadyMounted(int32_t mount_id) const {
  const auto mount_iter = mounts_.Find(mount_id);
  const bool is_mounted = mount_iter != mounts_.End();
  if (is_mounted) {
    DCHECK_EQ(mounted_share_paths_.count(mount_iter->second.mount_root), 1);
  }
  return is_mounted;
}
bool MountManager::IsAlreadyMounted(const std::string& mount_root) const {
bool has_credential = mounted_share_paths_.count(mount_root) > 0;
if (!has_credential) {
DCHECK(!ExistsInMounts(mount_root));
return false;
}
DCHECK(ExistsInMounts(mount_root));
return true;
}
// Mounts |mount_root| under a newly allocated id, which is written to
// |mount_id|. The credential is built from |workgroup|, |username| and the
// password read from |password_fd|. Returns false without changing any state
// if |mount_root| is already mounted. A successful call also clears
// |can_remount_|, which Remount() DCHECKs.
bool MountManager::AddMount(const std::string& mount_root,
                            const std::string& workgroup,
                            const std::string& username,
                            const base::ScopedFD& password_fd,
                            int32_t* mount_id) {
  DCHECK(mount_id);
  if (IsAlreadyMounted(mount_root)) {
    return false;
  }

  can_remount_ = false;

  SmbCredential credential(workgroup, username, GetPassword(password_fd));
  const int32_t new_mount_id =
      mounts_.Insert(CreateMountInfo(mount_root, std::move(credential)));
  AddSambaInterfaceIdToSambaInterfaceMap(new_mount_id);
  mounted_share_paths_.insert(mount_root);

  *mount_id = new_mount_id;
  return true;
}
// Re-creates a mount of |mount_root| under the caller-chosen |mount_id|, which
// must be non-negative and not already in use. The credential is built from
// |workgroup|, |username| and the password read from |password_fd|. Only valid
// while |can_remount_| is set. Returns false if |mount_root| is already
// mounted.
bool MountManager::Remount(const std::string& mount_root,
                           int32_t mount_id,
                           const std::string& workgroup,
                           const std::string& username,
                           const base::ScopedFD& password_fd) {
  DCHECK(can_remount_);
  DCHECK(!IsAlreadyMounted(mount_id));
  DCHECK_GE(mount_id, 0);

  const bool root_in_use = IsAlreadyMounted(mount_root);
  if (root_in_use) {
    return false;
  }

  mounts_.InsertWithSpecificId(
      mount_id,
      CreateMountInfo(mount_root,
                      SmbCredential(workgroup, username,
                                    GetPassword(password_fd))));
  AddSambaInterfaceIdToSambaInterfaceMap(mount_id);
  mounted_share_paths_.insert(mount_root);
  return true;
}
// Removes the mount identified by |mount_id|, together with its
// SambaInterface id mapping and its entry in |mounted_share_paths_|. Returns
// false if no such mount exists.
bool MountManager::RemoveMount(int32_t mount_id) {
  auto mount_iter = mounts_.Find(mount_id);
  const bool found = mount_iter != mounts_.End();
  if (!found) {
    return false;
  }

  // |root| refers into the MountInfo, so it is only used before the entry is
  // removed from |mounts_|.
  const std::string& root = mount_iter->second.mount_root;
  DeleteSambaInterfaceIdFromSambaInterfaceMap(mount_id);
  const bool path_removed = mounted_share_paths_.erase(root) > 0;
  DCHECK(path_removed);

  mounts_.Remove(mount_id);
  return true;
}
// Writes the mount root of |mount_id| joined with |entry_path| (via
// AppendPath) into |full_path|. Returns false, leaving |full_path| untouched,
// if |mount_id| is unknown.
bool MountManager::GetFullPath(int32_t mount_id,
                               const std::string& entry_path,
                               std::string* full_path) const {
  DCHECK(full_path);
  const auto mount_iter = mounts_.Find(mount_id);
  if (mount_iter != mounts_.End()) {
    *full_path = AppendPath(mount_iter->second.mount_root, entry_path);
    return true;
  }
  return false;
}
// Points |cache| at the metadata cache owned by mount |mount_id|. Returns
// false, leaving |cache| untouched, if |mount_id| is unknown.
bool MountManager::GetMetadataCache(int32_t mount_id,
                                    MetadataCache** cache) const {
  DCHECK(cache);
  const auto mount_iter = mounts_.Find(mount_id);
  if (mount_iter != mounts_.End()) {
    *cache = mount_iter->second.cache.get();
    DCHECK(*cache);
    return true;
  }
  return false;
}
// Returns |full_path| with the mount root of |mount_id| stripped from its
// front. |mount_id| must be mounted, and |full_path| must start with that root
// (compared case-insensitively in ASCII); both are only DCHECKed.
std::string MountManager::GetRelativePath(int32_t mount_id,
                                          const std::string& full_path) const {
  const auto mount_iter = mounts_.Find(mount_id);
  DCHECK(mount_iter != mounts_.End());
  const std::string& mount_root = mount_iter->second.mount_root;
  DCHECK(StartsWith(full_path, mount_root,
                    base::CompareCase::INSENSITIVE_ASCII));
  return full_path.substr(mount_root.size());
}
// Points |samba_interface| at the SambaInterface owned by mount |mount_id|.
// Returns false, leaving |samba_interface| untouched, if |mount_id| is unknown.
bool MountManager::GetSambaInterface(int32_t mount_id,
                                     SambaInterface** samba_interface) const {
  DCHECK(samba_interface);
  const auto mount_iter = mounts_.Find(mount_id);
  const bool found = mount_iter != mounts_.End();
  if (found) {
    *samba_interface = mount_iter->second.samba_interface.get();
    DCHECK(*samba_interface);
  }
  return found;
}
// Returns a non-owning pointer to the SambaInterface created in the
// constructor; MountManager retains ownership.
SambaInterface* MountManager::GetSystemSambaInterface() const {
  SambaInterface* const system_interface = system_samba_interface_.get();
  return system_interface;
}
// Returns the credential of the mount that owns |samba_interface_id|. The id
// must be present in |samba_interface_map_|.
const SmbCredential& MountManager::GetCredential(
    SambaInterface::SambaInterfaceId samba_interface_id) const {
  DCHECK_NE(samba_interface_map_.count(samba_interface_id), 0);
  // First hop: SambaInterfaceId -> mount id.
  const int32_t mount_id = samba_interface_map_.at(samba_interface_id);
  // Second hop: mount id -> MountInfo, whose credential is returned.
  const auto mount_iter = mounts_.Find(mount_id);
  DCHECK(mount_iter != mounts_.End());
  return mount_iter->second.credential;
}
std::unique_ptr<SambaInterface> MountManager::CreateSambaInterface() {
return samba_interface_factory_.Run(this);
}
// Fills |workgroup|, |username| and |password| with the credential of the
// mount that owns |samba_interface_id|. Each |*_length| is the size of the
// corresponding buffer in bytes, including room for the null terminator.
// Returns true on success. Returns false if no mount owns the id, if the
// buffers are too small, or if any buffer length is non-positive. In the first
// two cases every buffer is left holding an empty string. |share_path| is only
// used in log messages.
bool MountManager::GetAuthentication(
    SambaInterface::SambaInterfaceId samba_interface_id,
    const std::string& share_path,
    char* workgroup,
    int32_t workgroup_length,
    char* username,
    int32_t username_length,
    char* password,
    int32_t password_length) const {
  DCHECK_GT(workgroup_length, 0);
  DCHECK_GT(username_length, 0);
  DCHECK_GT(password_length, 0);

  // The DCHECKs above compile out of release builds. Every path below writes
  // at least a null terminator into each buffer, so refuse to touch the
  // buffers at all if any of them cannot hold even that one byte.
  if (workgroup_length <= 0 || username_length <= 0 || password_length <= 0) {
    LOG(ERROR) << "Invalid credential buffer lengths for " << share_path;
    return false;
  }

  // On failure, hand back empty strings rather than whatever the buffers held.
  const auto clear_buffers = [workgroup, username, password]() {
    SetBufferEmpty(workgroup);
    SetBufferEmpty(username);
    SetBufferEmpty(password);
  };

  if (samba_interface_map_.count(samba_interface_id) == 0) {
    LOG(ERROR) << "Credentials not found for " << share_path;
    clear_buffers();
    return false;
  }

  const SmbCredential& credential = GetCredential(samba_interface_id);
  if (!CanInputCredential(workgroup_length, username_length, password_length,
                          credential)) {
    LOG(ERROR) << "Buffers cannot support credentials for " << share_path;
    clear_buffers();
    return false;
  }

  PopulateCredential(credential, workgroup, username, password);
  return true;
}
// Test-only accessor that returns the credential stored for |mount_id|, which
// must be mounted.
const SmbCredential& MountManager::GetCredentialFromMountIdForTesting(
    int32_t mount_id) const {
  const auto iter = mounts_.Find(mount_id);
  DCHECK(iter != mounts_.End());
  const MountInfo& info = iter->second;
  return info.credential;
}
// Assembles the MountInfo for |mount_root|. It takes ownership of |credential|,
// shares the manager's tick clock, and gets a freshly created SambaInterface.
MountManager::MountInfo MountManager::CreateMountInfo(
    const std::string& mount_root, SmbCredential credential) {
  std::unique_ptr<SambaInterface> samba_interface = CreateSambaInterface();
  return MountInfo(mount_root, tick_clock_.get(), std::move(credential),
                   std::move(samba_interface));
}
// Returns the SambaInterfaceId of the system SambaInterface created in the
// constructor.
SambaInterface::SambaInterfaceId MountManager::GetSystemSambaInterfaceId() {
  const SambaInterface::SambaInterfaceId system_id =
      system_samba_interface_->GetSambaInterfaceId();
  return system_id;
}
// Returns the SambaInterfaceId of the SambaInterface owned by |mount_id|,
// which must be mounted.
SambaInterface::SambaInterfaceId MountManager::GetSambaInterfaceIdForMountId(
    int32_t mount_id) const {
  const auto mount_iter = mounts_.Find(mount_id);
  DCHECK(mount_iter != mounts_.End());
  const auto& mount_interface = mount_iter->second.samba_interface;
  return mount_interface->GetSambaInterfaceId();
}
void MountManager::AddSambaInterfaceIdToSambaInterfaceMap(int32_t mount_id) {
const SambaInterface::SambaInterfaceId samba_interface_id =
GetSambaInterfaceIdForMountId(mount_id);
DCHECK_EQ(0, samba_interface_map_.count(samba_interface_id));
samba_interface_map_[samba_interface_id] = mount_id;
}
void MountManager::DeleteSambaInterfaceIdFromSambaInterfaceMap(
int32_t mount_id) {
SambaInterface::SambaInterfaceId samba_interface_id =
GetSambaInterfaceIdForMountId(mount_id);
DCHECK_NE(0, samba_interface_map_.count(samba_interface_id));
samba_interface_map_.erase(samba_interface_id);
}
// Linearly scans |mounts_| and returns true if any mount has |mount_root| as
// its root. Used by IsAlreadyMounted() as a debug-only consistency check.
bool MountManager::ExistsInMounts(const std::string& mount_root) const {
  auto mount_iter = mounts_.Begin();
  const auto mounts_end = mounts_.End();
  while (mount_iter != mounts_end &&
         mount_iter->second.mount_root != mount_root) {
    ++mount_iter;
  }
  return mount_iter != mounts_end;
}
} // namespace smbprovider