// Copyright 2021 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "federated/federated_session.h"

#include <string>
#include <utility>

#include <base/logging.h>
#include <base/notreached.h>

#include "federated/device_status_monitor.h"
#include "federated/example_database.h"
#include "federated/federated_metadata.h"
#include "federated/protos/cros_events.pb.h"

namespace federated {

namespace {
#if USE_LOCAL_FEDERATED_SERVER
constexpr base::TimeDelta kDefaultRetryWindow = base::Seconds(30);
constexpr base::TimeDelta kMinimalRetryWindow = base::Seconds(10);
#else
// TODO(alanlxl): discussion required about the default window.
constexpr base::TimeDelta kDefaultRetryWindow = base::Seconds(60 * 5);

// To avoid spam, retry window should not be shorter than kMinimalRetryWindow.
constexpr base::TimeDelta kMinimalRetryWindow = base::Seconds(60);
#endif

// TODO(alanlxl): Just dummpy impl.
void LogCrosEvent(const fcp::client::CrosEvent& cros_event) {
  LOG(INFO) << "In LogCrosEvent, model_id is " << cros_event.model_id();

  if (cros_event.has_eligibility_eval_checkin()) {
    LOG(INFO) << "cros_event has_eligibility_eval_checkin";
  } else if (cros_event.has_eligibility_eval_plan_received()) {
    LOG(INFO) << "cros_event has_eligibility_eval_plan_received";
  } else if (cros_event.has_eligibility_eval_not_configured()) {
    LOG(INFO) << "cros_event.has_eligibility_eval_not_configured";
  } else if (cros_event.has_eligibility_eval_rejected()) {
    LOG(INFO) << "cros_event.has_eligibility_eval_rejected";
  } else if (cros_event.has_checkin()) {
    LOG(INFO) << "cros_event.has_checkin";
  } else if (cros_event.has_checkin_finished()) {
    LOG(INFO) << "cros_event.has_checkin_finished";
  } else if (cros_event.has_rejected()) {
    LOG(INFO) << "cros_event.has_rejected";
  } else if (cros_event.has_report_started()) {
    LOG(INFO) << "cros_event.has_report_started";
  } else if (cros_event.has_report_finished()) {
    LOG(INFO) << "cros_event.has_report_finished";
  } else if (cros_event.has_plan_execution_started()) {
    LOG(INFO) << "cros_event.has_plan_execution_started";
  } else if (cros_event.has_epoch_started()) {
    LOG(INFO) << "cros_event.has_epoch_started";
  } else if (cros_event.has_tensorflow_error()) {
    LOG(INFO) << "cros_event.has_tensorflow_error";
    DVLOG(1) << cros_event.DebugString();
  } else if (cros_event.has_io_error()) {
    LOG(INFO) << "cros_event.has_io_error";
  } else if (cros_event.has_example_selector_error()) {
    LOG(INFO) << "cros_event.has_example_selector_error";
  } else if (cros_event.has_interruption()) {
    LOG(INFO) << "cros_event.has_interruption";
  } else if (cros_event.has_epoch_completed()) {
    LOG(INFO) << "cros_event.has_epoch_completed";
  } else if (cros_event.has_stats()) {
    LOG(INFO) << "cros_event.has_stats";
  } else if (cros_event.has_plan_completed()) {
    LOG(INFO) << "cros_event.has_plan_completed";
  } else {
    LOG(INFO) << "cros_event doesn't have any event log";
  }
}

// TODO(alanlxl): Just dummpy impl.
void LogCrosSecAggEvent(const fcp::client::CrosSecAggEvent& cros_secagg_event) {
  LOG(INFO) << "In LogCrosSecAggEvent, session_id is "
            << cros_secagg_event.execution_session_id();

  if (cros_secagg_event.has_state_transition())
    LOG(INFO) << "cros_secagg_event.has_state_transition";
  else if (cros_secagg_event.has_error())
    LOG(INFO) << "cros_secagg_event.has_error";
  else if (cros_secagg_event.has_abort())
    LOG(INFO) << "cros_secagg_event.has_abort";
  else
    LOG(INFO) << "cros_secagg_event doesn't have any event log";
}

}  // namespace

FederatedSession::Context::Context(
    const DeviceStatusMonitor* const device_status_monitor,
    ExampleDatabase::Iterator&& example_iterator)
    : device_status_monitor_(device_status_monitor),
      example_iterator_(std::move(example_iterator)) {}

FederatedSession::Context::~Context() = default;

bool FederatedSession::Context::GetNextExample(const char** const data,
                                               int* const size,
                                               bool* const end,
                                               void* const context) {
  if (context == nullptr)
    return false;

  const absl::StatusOr<ExampleRecord> record =
      static_cast<FederatedSession::Context*>(context)
          ->example_iterator_.Next();

  if (absl::IsInvalidArgument(record.status())) {
    return false;
  }

  if (record.ok()) {
    *end = false;
    *size = record->serialized_example.size();
    char* const str_data = new char[*size];
    record->serialized_example.copy(str_data, *size);
    *data = str_data;
  } else {
    DCHECK(absl::IsOutOfRange(record.status()));
    *end = true;
  }

  return true;
}

void FederatedSession::Context::FreeExample(const char* const data,
                                            void* const context) {
  delete[] data;
}

bool FederatedSession::Context::TrainingConditionsSatisfied(
    void* const context) {
  if (context == nullptr)
    return false;

  return static_cast<FederatedSession::Context*>(context)
      ->device_status_monitor_->TrainingConditionsSatisfied();
}

void FederatedSession::Context::PublishEvent(const char* const event,
                                             const int size,
                                             void* const context) {
  if (context == nullptr) {
    LOG(ERROR) << "PublishEvent gets nullptr context.";
    return;
  }

  fcp::client::CrosEventLog event_log;
  if (!event_log.ParseFromArray(event, size)) {
    LOG(ERROR) << "Failed to parse event_log.";
    return;
  }

  if (event_log.has_event()) {
    LogCrosEvent(event_log.event());
  } else if (event_log.has_secagg_event()) {
    LogCrosSecAggEvent(event_log.secagg_event());
  } else {
    LOG(ERROR) << "event_log has no content";
  }
}

FederatedSession::FederatedSession(
    const FlRunPlanFn run_plan,
    const FlFreeRunPlanResultFn free_run_plan_result,
    const std::string& service_uri,
    const std::string& api_key,
    const ClientConfigMetadata client_config,
    const DeviceStatusMonitor* const device_status_monitor)
    : run_plan_(run_plan),
      free_run_plan_result_(free_run_plan_result),
      service_uri_(service_uri),
      api_key_(api_key),
      client_config_(client_config),
      next_retry_delay_(kDefaultRetryWindow),
      device_status_monitor_(device_status_monitor) {}

FederatedSession::~FederatedSession() = default;

void FederatedSession::RunPlan(ExampleDatabase::Iterator&& example_iterator) {
  FederatedSession::Context context(device_status_monitor_,
                                    std::move(example_iterator));

  const FlTaskEnvironment env = {
      &FederatedSession::Context::GetNextExample,
      &FederatedSession::Context::FreeExample,
      &FederatedSession::Context::TrainingConditionsSatisfied,
      &FederatedSession::Context::PublishEvent,
      client_config_.base_dir.c_str(),
      &context};

  FlRunPlanResult result = (*run_plan_)(
      env, service_uri_.c_str(), api_key_.c_str(), client_config_.name.c_str(),
      client_config_.retry_token.c_str());

  // TODO(alanlxl): maybe log the event to UMA
  if (result.status == CONTRIBUTED || result.status == REJECTED_BY_SERVER) {
    DVLOG(1) << "result.status = " << result.status;
    DVLOG(1) << "result.retry_token = " << result.retry_token;
    DVLOG(1) << "result.delay_usecs = " << result.delay_usecs;
    client_config_.retry_token = std::string(result.retry_token);
    next_retry_delay_ = base::Microseconds(result.delay_usecs);

    // TODO(alanlxl): result.delay_usecs may be 0 when setup is wrong, now I set
    // next_retry_delay_ to kMinimalRetryWindow to avoid spam, consider stopping
    // retry in this case because it's very likely to fail again.
    if (next_retry_delay_ < kMinimalRetryWindow)
      next_retry_delay_ = kMinimalRetryWindow;

  } else {
    DVLOG(1) << "Failed to checkin with the servce, result.status = "
             << result.status;
    next_retry_delay_ = kDefaultRetryWindow;
  }

  (*free_run_plan_result_)(result);
}

void FederatedSession::ResetRetryDelay() {
  next_retry_delay_ = kDefaultRetryWindow;
}

std::string FederatedSession::GetSessionName() const {
  return client_config_.name;
}

}  // namespace federated
