// Copyright 2021 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "minios/network_manager.h"

#include <base/bind.h>
#include <base/logging.h>
#include <base/strings/stringprintf.h>
#include <brillo/errors/error_codes.h>
#include <brillo/message_loops/message_loop.h>
#include <dbus/shill/dbus-constants.h>

#include "minios/shill_utils.h"

namespace minios {

namespace {

// Formats |error| as "code=<code>,message=<message>" for log messages.
// A null |error| yields an empty string, so callers may log the result
// unconditionally.
std::string ToString(brillo::Error* error) {
  if (error == nullptr)
    return std::string();
  const std::string& code = error->GetCode();
  const std::string& message = error->GetMessage();
  return base::StringPrintf("code=%s,message=%s", code.c_str(),
                            message.c_str());
}

}  // namespace

// Takes ownership of |shill_proxy|; every shill call made by this class goes
// through it. The weak pointer factory hands out the weak pointers bound into
// the asynchronous shill callbacks below.
NetworkManager::NetworkManager(std::unique_ptr<ShillProxyInterface> shill_proxy)
    : shill_proxy_(std::move(shill_proxy)),
      weak_ptr_factory_(this) {}

// Starts connecting to the WiFi network |ssid| with |passphrase|. An empty
// passphrase selects an open (no security) network. If a `Connect()` for the
// same |ssid| is already in flight, this call is folded into it and nothing
// new is started. The outcome is reported to observers via `OnConnect()`.
void NetworkManager::Connect(const std::string& ssid,
                             const std::string& passphrase) {
  // A single lookup both detects an in-flight request and registers a new one.
  const auto result =
      connect_map_.try_emplace(ssid, ConnectField{.passphrase = passphrase});
  // Bundle in with the already processing `Connect()`.
  if (!result.second)
    return;
  ConnectMapIter iter = result.first;

  shill_proxy_->ManagerRequestScan(
      WifiTechnologyType::WIFI,
      base::Bind(static_cast<void (NetworkManager::*)(ConnectMapIter)>(
                     &NetworkManager::RequestScanSuccess),
                 weak_ptr_factory_.GetWeakPtr(), iter),
      base::Bind(
          static_cast<void (NetworkManager::*)(ConnectMapIter, brillo::Error*)>(
              &NetworkManager::RequestScanError),
          weak_ptr_factory_.GetWeakPtr(), iter));
}

// The WiFi scan for |iter|'s SSID completed. Asks shill for a managed WiFi
// service whose name is the SSID and whose security class matches the
// passphrase: NONE when it is empty, PSK otherwise.
void NetworkManager::RequestScanSuccess(ConnectMapIter iter) {
  const std::string& ssid = iter->first;
  LOG(INFO) << "RequestScan success for SSID=" << ssid;

  // If there is no passphrase, default to no security.
  const bool open_network = iter->second.passphrase.empty();
  const auto security = open_network ? ToString(WifiSecurityType::NONE)
                                     : ToString(WifiSecurityType::PSK);

  const brillo::VariantDictionary properties = {
      // Mode needs to be set from supported station type.
      {shill::kModeProperty, brillo::Any(ToString(WifiStationType::MANAGED))},
      // SSID of the wireless network.
      {shill::kNameProperty, brillo::Any(ssid)},
      {shill::kSecurityClassProperty, brillo::Any(security)},
      {shill::kTypeProperty, brillo::Any(ToString(WifiTechnologyType::WIFI))},
  };

  auto weak_this = weak_ptr_factory_.GetWeakPtr();
  shill_proxy_->ManagerFindMatchingService(
      properties,
      base::Bind(&NetworkManager::FindServiceSuccess, weak_this, iter),
      base::Bind(&NetworkManager::FindServiceError, weak_this, iter));
}

// The WiFi scan for |iter|'s SSID failed: log |error| and finish the
// `Connect()` request with it.
void NetworkManager::RequestScanError(ConnectMapIter iter,
                                      brillo::Error* error) {
  const std::string& ssid = iter->first;
  LOG(ERROR) << "RequestScan failed for SSID=" << ssid << ": "
             << ToString(error);
  Return(iter, error);
}

// shill found a service matching |iter|'s SSID. Records its object path in
// the request and fetches the service's properties, so the signal strength
// can be checked by GetServiceSuccess().
void NetworkManager::FindServiceSuccess(ConnectMapIter iter,
                                        const dbus::ObjectPath& service_path) {
  LOG(INFO) << "FindService success for SSID=" << iter->first
            << ": found object path " << service_path.value();
  ConnectField& request = iter->second;
  request.service_path = service_path;

  auto weak_this = weak_ptr_factory_.GetWeakPtr();
  shill_proxy_->ServiceGetProperties(
      service_path,
      base::Bind(&NetworkManager::GetServiceSuccess, weak_this, iter),
      base::Bind(&NetworkManager::GetServiceError, weak_this, iter));
}

// No matching service could be found for |iter|'s SSID: log |error| and
// finish the `Connect()` request with it.
void NetworkManager::FindServiceError(ConnectMapIter iter,
                                      brillo::Error* error) {
  const std::string& ssid = iter->first;
  LOG(ERROR) << "FindService failed for SSID=" << ssid << ": "
             << ToString(error);
  Return(iter, error);
}

// Handles the properties of the service matched for |iter|'s SSID.
// Connecting proceeds only when the service reports a non-zero signal
// strength. In that case the passphrase is written to the service, with
// auto-connect enabled, and ConfigureNetworkSuccess()/ConfigureNetworkError()
// continue the chain. A missing or zero strength finishes the request with an
// error.
void NetworkManager::GetServiceSuccess(ConnectMapIter iter,
                                       const brillo::VariantDictionary& dict) {
  LOG(INFO) << "GetService success for SSID=" << iter->first;

  // Look the property up directly instead of scanning every entry.
  if (dict.find(shill::kSignalStrengthProperty) == dict.end()) {
    Return(iter, brillo::Error::Create(
                     FROM_HERE, brillo::errors::dbus::kDomain, DBUS_ERROR_FAILED,
                     "Strength is missing for SSID=" + iter->first)
                     .get());
    return;
  }

  // Check the strength of the service before continuing to connect.
  const uint8_t strength = brillo::GetVariantValueOrDefault<uint8_t>(
      dict, shill::kSignalStrengthProperty);
  if (strength == 0) {
    Return(iter,
           brillo::Error::Create(
               FROM_HERE, brillo::errors::dbus::kDomain, DBUS_ERROR_FAILED,
               "Strength is too weak to connect for SSID=" + iter->first)
               .get());
    return;
  }

  const brillo::VariantDictionary properties = {
      {shill::kAutoConnectProperty, brillo::Any(true)},
      {shill::kPassphraseProperty, brillo::Any(iter->second.passphrase)},
  };
  // Set the SSID passphrase and proceed with connecting.
  shill_proxy_->ServiceSetProperties(
      iter->second.service_path, properties,
      base::Bind(&NetworkManager::ConfigureNetworkSuccess,
                 weak_ptr_factory_.GetWeakPtr(), iter),
      base::Bind(&NetworkManager::ConfigureNetworkError,
                 weak_ptr_factory_.GetWeakPtr(), iter));
}

// Reading the matched service's properties failed: log |error| and finish
// the `Connect()` request with it.
void NetworkManager::GetServiceError(ConnectMapIter iter,
                                     brillo::Error* error) {
  const std::string& ssid = iter->first;
  LOG(ERROR) << "GetService failed for SSID=" << ssid << ": "
             << ToString(error);
  Return(iter, error);
}

// The passphrase/auto-connect properties were applied to the service; move
// on to the actual connect request.
void NetworkManager::ConfigureNetworkSuccess(ConnectMapIter iter) {
  const std::string& ssid = iter->first;
  LOG(INFO) << "ConfigureNetwork success for SSID=" << ssid;
  ServiceConnect(iter);
}

// Applying the passphrase/auto-connect properties failed: log |error| and
// finish the `Connect()` request with it.
void NetworkManager::ConfigureNetworkError(ConnectMapIter iter,
                                           brillo::Error* error) {
  const std::string& ssid = iter->first;
  LOG(ERROR) << "ConfigureNetwork failed for SSID=" << ssid << ": "
             << ToString(error);
  Return(iter, error);
}

// Asks shill to connect to the service recorded for |iter|. It is also
// re-posted with a delay by ConnectToNetworkError() while shill reports the
// connection as still in progress.
void NetworkManager::ServiceConnect(ConnectMapIter iter) {
  const dbus::ObjectPath& service_path = iter->second.service_path;
  auto weak_this = weak_ptr_factory_.GetWeakPtr();
  shill_proxy_->ServiceConnect(
      service_path,
      base::Bind(&NetworkManager::ConnectToNetworkSuccess, weak_this, iter),
      base::Bind(&NetworkManager::ConnectToNetworkError, weak_this, iter));
}

// shill accepted the connect request. Reads the service's properties again
// so GetServiceCheckConnectionSuccess() can inspect the connection state;
// that handler re-posts this method while the connection is still being
// established.
void NetworkManager::ConnectToNetworkSuccess(ConnectMapIter iter) {
  LOG(INFO) << "ConnectToNetwork success for SSID=" << iter->first
            << " proceeding to verify connection.";
  const dbus::ObjectPath& service_path = iter->second.service_path;
  auto weak_this = weak_ptr_factory_.GetWeakPtr();
  shill_proxy_->ServiceGetProperties(
      service_path,
      base::Bind(&NetworkManager::GetServiceCheckConnectionSuccess, weak_this,
                 iter),
      base::Bind(&NetworkManager::GetServiceCheckConnectionError, weak_this,
                 iter));
}

// Handles a failed `ServiceConnect()` for |iter|. The outcome depends on the
// error code:
// - "in progress": ServiceConnect() is retried after
//   kConnectionRetryMsDelay milliseconds;
// - "already connected": the request finishes without an error;
// - anything else: the request finishes with |error|.
void NetworkManager::ConnectToNetworkError(ConnectMapIter iter,
                                           brillo::Error* error) {
  // Fetch the code once and use it for every comparison below.
  const std::string& error_code = error->GetCode();
  if (error_code == shill::kErrorResultInProgress) {
    LOG(INFO) << "ConnectToNetwork failed, but connection is in progress for "
              << "SSID=" << iter->first;
    // Try connecting again until connection is not in progress or gets
    // connected.
    // NOTE(review): no retry limit is visible here; if shill keeps reporting
    // "in progress" this loops indefinitely — confirm that is intended.
    brillo::MessageLoop::current()->PostDelayedTask(
        FROM_HERE,
        base::Bind(&NetworkManager::ServiceConnect,
                   weak_ptr_factory_.GetWeakPtr(), iter),
        base::TimeDelta::FromMilliseconds(kConnectionRetryMsDelay));
  } else if (error_code == shill::kErrorResultAlreadyConnected) {
    LOG(INFO) << "ConnectToNetwork failed, but already connected for SSID="
              << iter->first;
    Return(iter);
  } else {
    LOG(ERROR) << "ConnectToNetwork failed for SSID=" << iter->first << ": "
               << ToString(error);
    Return(iter, error);
  }
}

// Inspects the service's State property after a connect request for |iter|:
// - "online": the request finishes successfully;
// - association/configuration/ready: ConnectToNetworkSuccess() is re-posted
//   after kCheckConnectionRetryMsDelay milliseconds to poll again;
// - any other value: the request finishes with a "Connection failed" error.
// A missing State property finishes the request with its own error.
void NetworkManager::GetServiceCheckConnectionSuccess(
    ConnectMapIter iter, const brillo::VariantDictionary& dict) {
  // Look the property up directly instead of scanning every entry.
  const auto state_it = dict.find(shill::kStateProperty);
  if (state_it == dict.end()) {
    Return(iter, brillo::Error::Create(
                     FROM_HERE, brillo::errors::dbus::kDomain, DBUS_ERROR_FAILED,
                     "Connection property missing for SSID=" + iter->first)
                     .get());
    return;
  }

  // A State value that is not a string is treated like an unrecognized state
  // (a failed connection). It no longer trips the type check inside
  // `brillo::Any::Get()`.
  const std::string state = state_it->second.IsTypeCompatible<std::string>()
                                ? state_it->second.Get<std::string>()
                                : std::string();
  LOG(INFO) << "GetServiceCheckConnection state is " << state;

  if (state == shill::kStateOnline) {
    Return(iter);
    return;
  }

  if (state == shill::kStateAssociation ||
      state == shill::kStateConfiguration || state == shill::kStateReady) {
    // Still connecting; check again later.
    brillo::MessageLoop::current()->PostDelayedTask(
        FROM_HERE,
        base::Bind(&NetworkManager::ConnectToNetworkSuccess,
                   weak_ptr_factory_.GetWeakPtr(), iter),
        base::TimeDelta::FromMilliseconds(kCheckConnectionRetryMsDelay));
    return;
  }

  Return(iter, brillo::Error::Create(
                   FROM_HERE, brillo::errors::dbus::kDomain, DBUS_ERROR_FAILED,
                   "Connection failed for SSID=" + iter->first)
                   .get());
}

// Reading the service's properties to verify the connection failed: log
// |error| and finish the `Connect()` request with it.
void NetworkManager::GetServiceCheckConnectionError(ConnectMapIter iter,
                                                    brillo::Error* error) {
  const std::string& ssid = iter->first;
  LOG(ERROR) << "GetServiceCheckConnection failed for SSID=" << ssid
             << ": " << ToString(error);
  Return(iter, error);
}

// Finishes the `Connect()` request tracked by |iter|. The entry is removed
// from |connect_map_| and every observer's `OnConnect()` is called with the
// SSID and |error|. Success paths call this without an error argument.
void NetworkManager::Return(ConnectMapIter iter, brillo::Error* error) {
  // Erase before notifying. An observer that calls `Connect()` for the same
  // SSID from inside `OnConnect()` then starts a fresh request instead of being
  // bundled into this finished one and silently dropped. |iter| is not touched
  // after the observers run.
  const std::string ssid = iter->first;
  connect_map_.erase(iter);
  for (auto& observer : observers_)
    observer.OnConnect(ssid, error);
}

// Starts collecting the names of the services shill knows about, beginning
// with a WiFi scan. Results reach observers through `OnGetNetworks()`. Only
// one request is tracked at a time; calls made while one is pending are
// folded into it.
void NetworkManager::GetNetworks() {
  // Bundle in with the already processing `GetNetworks()`.
  const bool request_pending = !get_networks_list_.empty();
  if (request_pending)
    return;

  // The list was empty, so the node added here is its only element and
  // `begin()` refers to it. `get_networks_list_`'s max size should never
  // exceed a single node.
  get_networks_list_.emplace_back();
  GetNetworksListIter iter = get_networks_list_.begin();

  auto weak_this = weak_ptr_factory_.GetWeakPtr();
  shill_proxy_->ManagerRequestScan(
      WifiTechnologyType::WIFI,
      base::Bind(static_cast<GetNetworksRequestScanSuccessType>(
                     &NetworkManager::RequestScanSuccess),
                 weak_this, iter),
      base::Bind(static_cast<GetNetworksRequestScanErrorType>(
                     &NetworkManager::RequestScanError),
                 weak_this, iter));
}

// The WiFi scan for `GetNetworks()` completed. Reads shill's manager
// properties to obtain the list of known services.
void NetworkManager::RequestScanSuccess(GetNetworksListIter iter) {
  LOG(INFO) << "RequestScan success.";
  auto weak_this = weak_ptr_factory_.GetWeakPtr();
  shill_proxy_->ManagerGetProperties(
      base::Bind(&NetworkManager::GetGlobalPropertiesSuccess, weak_this, iter),
      base::Bind(&NetworkManager::GetGlobalPropertiesError, weak_this, iter));
}

// The WiFi scan for `GetNetworks()` failed: log |error| and report it to
// observers along with whatever has been collected so far.
void NetworkManager::RequestScanError(GetNetworksListIter iter,
                                      brillo::Error* error) {
  const std::string description = ToString(error);
  LOG(ERROR) << "RequestScan failed: " << description;
  Return(iter, error);
}

// Handles shill's manager properties for a `GetNetworks()` request. When the
// Services property holds a non-empty list of object paths, the paths are
// queued on |iter| and the first service's properties are requested. Each
// reply is handled by IterateOverServicePropertiesSuccess()/Error(). If the
// property is missing or empty, the request finishes immediately with no
// networks.
void NetworkManager::GetGlobalPropertiesSuccess(
    GetNetworksListIter iter, const brillo::VariantDictionary& dict) {
  LOG(INFO) << "GetGlobalProperties success.";

  // Look the property up directly instead of scanning every entry.
  if (dict.find(shill::kServicesProperty) != dict.end()) {
    const auto services =
        brillo::GetVariantValueOrDefault<std::vector<dbus::ObjectPath>>(
            dict, shill::kServicesProperty);
    if (services.empty()) {
      LOG(ERROR) << "No services found.";
    } else {
      // Copy the list of services to read from.
      iter->service_paths.assign(services.begin(), services.end());
      // Dequeue the next path *before* dispatching. If the proxy runs the
      // callback synchronously, it then sees an advanced queue, and |iter|
      // (which may be cleared by `Return()`) is not touched afterwards.
      const dbus::ObjectPath next_service = iter->service_paths.back();
      iter->service_paths.pop_back();
      // Start the iterations over each service asynchronously.
      shill_proxy_->ServiceGetProperties(
          next_service,
          base::Bind(&NetworkManager::IterateOverServicePropertiesSuccess,
                     weak_ptr_factory_.GetWeakPtr(), iter),
          base::Bind(&NetworkManager::IterateOverServicePropertiesError,
                     weak_ptr_factory_.GetWeakPtr(), iter));
      return;
    }
  }
  Return(iter);
}

// Reading shill's manager properties failed: log |error| and report it to
// observers.
void NetworkManager::GetGlobalPropertiesError(GetNetworksListIter iter,
                                              brillo::Error* error) {
  const std::string description = ToString(error);
  LOG(ERROR) << "GetGlobalProperties failed: " << description;
  Return(iter, error);
}

// Handles one service's properties during `GetNetworks()`. A non-empty Name
// (SSID) is appended to the collected networks. Then either the next queued
// service is requested, or, once the queue is empty, the request finishes.
void NetworkManager::IterateOverServicePropertiesSuccess(
    GetNetworksListIter iter, const brillo::VariantDictionary& dict) {
  LOG(INFO) << "IterateOverServiceProperties success.";

  // Save the name (SSID) if it's not empty.
  auto name =
      brillo::GetVariantValueOrDefault<std::string>(dict, shill::kNameProperty);
  if (!name.empty())
    iter->networks.push_back(std::move(name));

  // Iterated over all services.
  if (iter->service_paths.empty()) {
    Return(iter);
    return;
  }

  // Dequeue the next path *before* dispatching. If the proxy runs the callback
  // synchronously, it then sees an advanced queue, and |iter| (which may be
  // cleared by `Return()`) is not touched afterwards.
  const dbus::ObjectPath next_service = iter->service_paths.back();
  iter->service_paths.pop_back();

  // Iterate over the next service.
  shill_proxy_->ServiceGetProperties(
      next_service,
      base::Bind(&NetworkManager::IterateOverServicePropertiesSuccess,
                 weak_ptr_factory_.GetWeakPtr(), iter),
      base::Bind(&NetworkManager::IterateOverServicePropertiesError,
                 weak_ptr_factory_.GetWeakPtr(), iter));
}

// Reading one service's properties failed during `GetNetworks()`. The error
// is logged and that service is skipped: the next queued service is
// requested, or, once the queue is empty, the request finishes without
// reporting an error.
void NetworkManager::IterateOverServicePropertiesError(GetNetworksListIter iter,
                                                       brillo::Error* error) {
  LOG(ERROR) << "IterateOverServiceProperties failed: " << ToString(error);
  if (iter->service_paths.empty()) {
    Return(iter);
    return;
  }

  // Dequeue the next path *before* dispatching. If the proxy runs the callback
  // synchronously, it then sees an advanced queue, and |iter| (which may be
  // cleared by `Return()`) is not touched afterwards.
  const dbus::ObjectPath next_service = iter->service_paths.back();
  iter->service_paths.pop_back();
  shill_proxy_->ServiceGetProperties(
      next_service,
      base::Bind(&NetworkManager::IterateOverServicePropertiesSuccess,
                 weak_ptr_factory_.GetWeakPtr(), iter),
      base::Bind(&NetworkManager::IterateOverServicePropertiesError,
                 weak_ptr_factory_.GetWeakPtr(), iter));
}

// Finishes the pending `GetNetworks()` request. The tracked request is
// cleared, then every observer's `OnGetNetworks()` is called with the
// collected names and |error|.
void NetworkManager::Return(GetNetworksListIter iter, brillo::Error* error) {
  // Take the results and clear the pending request before notifying. An
  // observer that calls `GetNetworks()` from inside `OnGetNetworks()` then
  // starts a new scan instead of being bundled into this finished request and
  // dropped.
  auto networks = std::move(iter->networks);
  get_networks_list_.clear();
  for (auto& observer : observers_)
    observer.OnGetNetworks(networks, error);
}

}  // namespace minios
