// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "shill/portal_detector.h"

#include <string>

#include <base/bind.h>
#include <base/rand_util.h>
#include <base/strings/pattern.h>
#include <base/strings/string_number_conversions.h>
#include <base/strings/string_util.h>
#include <base/strings/stringprintf.h>
#include <chromeos/dbus/service_constants.h>

#include "shill/connection.h"
#include "shill/dns_client.h"
#include "shill/event_dispatcher.h"
#include "shill/logging.h"
#include "shill/metrics.h"
#include "shill/net/ip_address.h"

using base::Bind;
using base::Callback;
using base::StringPrintf;
using std::string;

namespace {

// Placeholder that may appear in the HTTP probe URL.  RandomizeURL() replaces
// it with a number drawn from [kMinRandomHost, kMaxRandomHost].
const char kRandomKeyword[] = "${RAND}";

// This range is determined by the server-side configuration.  See b/63033351
const int kMinRandomHost = 1;
const int kMaxRandomHost = 25;

// If |url| contains |kRandomKeyword|, replaces its first occurrence with a
// random number in the inclusive range [|kMinRandomHost|, |kMaxRandomHost|]
// and returns the resulting string.  Otherwise returns |url| unchanged.  This
// is used to rotate through alternate hostnames (e.g. alt1..alt25) on each
// portal check, to defeat IP-based blocking.
//
// |url| is taken by value: the substitution is done in place on that copy,
// which is then returned.
string RandomizeURL(string url) {
  const int alt_host = base::RandInt(kMinRandomHost, kMaxRandomHost);
  base::ReplaceFirstSubstringAfterOffset(&url, 0, kRandomKeyword,
                                         base::IntToString(alt_host));
  return url;
}

}  // namespace

namespace shill {

namespace Logging {
// Log scope and per-object tag for this file.  Presumably these are looked up
// by the SLOG() macro (see shill/logging.h); every SLOG() call here passes a
// Connection*, which is tagged with its interface name.
static auto kModuleLogScope = ScopeLogger::kPortal;
static string ObjectID(Connection* connection) {
  return connection->interface_name();
}
}  // namespace Logging

// Initial and maximum portal check intervals, in seconds (as named; they are
// consumed outside this file).
const int PortalDetector::kInitialCheckIntervalSeconds = 3;
const int PortalDetector::kMaxPortalCheckIntervalSeconds = 5 * 60;
// Default comma-separated technology list for portal checking.
const char PortalDetector::kDefaultCheckPortalList[] = "ethernet,wifi,cellular";

// Delay before StartTrialTask()'s timeout fires and the trial is abandoned.
const int PortalDetector::kRequestTimeoutSeconds = 10;

// Default probe endpoints.
const char PortalDetector::kDefaultHttpUrl[] =
    "http://www.gstatic.com/generate_204";
const char PortalDetector::kDefaultHttpsUrl[] =
    "https://www.google.com/generate_204";
const std::vector<string> PortalDetector::kDefaultFallbackHttpUrls = {
    "http://www.google.com/gen_204",
    "http://play.googleapis.com/generate_204",
};

PortalDetector::PortalDetector(
    ConnectionRefPtr connection,
    EventDispatcher* dispatcher,
    Metrics* metrics,
    const Callback<void(const PortalDetector::Result&)>& callback)
    : attempt_count_(0),
      attempt_start_time_((struct timeval){0}),
      connection_(connection),
      dispatcher_(dispatcher),
      metrics_(metrics),
      weak_ptr_factory_(this),
      portal_result_callback_(callback),
      time_(Time::GetInstance()),
      is_active_(false) {}

// Tears down any outstanding trial and releases both probe requests via
// Stop().
PortalDetector::~PortalDetector() { Stop(); }

// Schedules a portal detection trial to begin |delay_seconds| from now.
// Returns false, and schedules nothing, if a probe URL in |props| cannot be
// parsed.
bool PortalDetector::StartAfterDelay(const PortalDetector::Properties& props,
                                     int delay_seconds) {
  SLOG(connection_.get(), 3) << "In " << __func__;

  const int delay_milliseconds = delay_seconds * 1000;
  const bool started = StartTrial(props, delay_milliseconds);
  if (started) {
    // Remember when this attempt is due to begin (now + |delay_seconds|).
    // AdjustStartDelay() measures the time elapsed since this point.
    UpdateAttemptTime(delay_seconds);
  }
  return started;
}

bool PortalDetector::StartTrial(const Properties& props,
                                int start_delay_milliseconds) {
  SLOG(connection_.get(), 3) << "In " << __func__;

  // This step is rerun on each attempt, but trying it here will allow
  // Start() to abort on any obviously malformed URL strings.
  HttpUrl http_url, https_url;
  if (!http_url.ParseFromString(RandomizeURL(props.http_url_string))) {
    LOG(ERROR) << "Failed to parse URL string: " << props.http_url_string;
    return false;
  }
  if (!https_url.ParseFromString(props.https_url_string)) {
    LOG(ERROR) << "Failed to parse URL string: " << props.https_url_string;
    return false;
  }
  http_url_string_ = props.http_url_string;
  https_url_string_ = props.https_url_string;

  if (http_request_ || https_request_) {
    CleanupTrial();
  } else {
    http_request_ = std::make_unique<HttpRequest>(connection_, dispatcher_);
    https_request_ = std::make_unique<HttpRequest>(connection_, dispatcher_);
  }
  StartTrialAfterDelay(start_delay_milliseconds);
  attempt_count_++;
  return true;
}

// Arms |trial_| to run StartTrialTask() after |start_delay_milliseconds|.
// The task is bound to a weak pointer, so it is dropped if this detector is
// destroyed first.  Per base::CancelableCallback::Reset(), any start that was
// previously wrapped by |trial_| is cancelled.
void PortalDetector::StartTrialAfterDelay(int start_delay_milliseconds) {
  SLOG(connection_.get(), 4)
      << "In " << __func__ << " delay = " << start_delay_milliseconds << "ms.";
  auto start_task =
      Bind(&PortalDetector::StartTrialTask, weak_ptr_factory_.GetWeakPtr());
  trial_.Reset(start_task);
  dispatcher_->PostDelayedTask(FROM_HERE, trial_.callback(),
                               start_delay_milliseconds);
}

void PortalDetector::StartTrialTask() {
  base::Callback<void(std::shared_ptr<brillo::http::Response>)>
      http_request_success_callback(
          Bind(&PortalDetector::HttpRequestSuccessCallback,
               weak_ptr_factory_.GetWeakPtr()));
  base::Callback<void(HttpRequest::Result)> http_request_error_callback(
      Bind(&PortalDetector::HttpRequestErrorCallback,
           weak_ptr_factory_.GetWeakPtr()));
  HttpRequest::Result http_result = http_request_->Start(
      RandomizeURL(http_url_string_), http_request_success_callback,
      http_request_error_callback);
  if (http_result != HttpRequest::kResultInProgress) {
    CompleteTrial(PortalDetector::GetPortalResultForRequestResult(http_result));
    return;
  }

  base::Callback<void(std::shared_ptr<brillo::http::Response>)>
      https_request_success_callback(
          Bind(&PortalDetector::HttpsRequestSuccessCallback,
               weak_ptr_factory_.GetWeakPtr()));
  base::Callback<void(HttpRequest::Result)> https_request_error_callback(
      Bind(&PortalDetector::HttpsRequestErrorCallback,
           weak_ptr_factory_.GetWeakPtr()));
  HttpRequest::Result https_result =
      https_request_->Start(https_url_string_, https_request_success_callback,
                            https_request_error_callback);
  if (https_result != HttpRequest::kResultInProgress) {
    Result trial_result = GetPortalResultForRequestResult(https_result);
    LOG(ERROR) << connection_->interface_name()
               << StringPrintf(
                      " HTTPS probe start failed phase==%s, status==%s, "
                      "attempt count==%d",
                      PhaseToString(trial_result.phase).c_str(),
                      StatusToString(trial_result.status).c_str(),
                      attempt_count_);
  }
  is_active_ = true;

  trial_timeout_.Reset(
      Bind(&PortalDetector::TimeoutTrialTask, weak_ptr_factory_.GetWeakPtr()));
  dispatcher_->PostDelayedTask(FROM_HERE, trial_timeout_.callback(),
                               kRequestTimeoutSeconds * 1000);
}

// True once StartTrialTask() has launched the HTTP probe, until CleanupTrial()
// clears it.
bool PortalDetector::IsActive() { return is_active_; }

// Logs the trial outcome at verbose level 3, then finishes the attempt with
// |result|.
void PortalDetector::CompleteTrial(Result result) {
  SLOG(connection_.get(), 3)
      << StringPrintf("Trial completed with phase==%s, status==%s, "
                      "attempt count==%d",
                      PhaseToString(result.phase).c_str(),
                      StatusToString(result.status).c_str(), attempt_count_);
  CompleteAttempt(result);
}

void PortalDetector::CleanupTrial() {
  trial_timeout_.Cancel();

  http_result_.reset();
  https_result_.reset();
  if (http_request_)
    http_request_->Stop();
  if (https_request_)
    https_request_->Stop();

  is_active_ = false;
}

void PortalDetector::TimeoutTrialTask() {
  LOG(ERROR) << connection_->interface_name()
             << " Trial request timed out, attempt count==" << attempt_count_;
  CompleteTrial(Result(Phase::kUnknown, Status::kTimeout));
}

void PortalDetector::Stop() {
  SLOG(connection_.get(), 3) << "In " << __func__;

  attempt_count_ = 0;
  if (!http_request_ && !https_request_)
    return;

  CleanupTrial();
  http_request_.reset();
  https_request_.reset();
}

// Called whenever a probe records its result.  Once both the HTTP and HTTPS
// results are present, reports them to metrics and completes the trial with
// the HTTP probe's result.
void PortalDetector::CompleteRequest() {
  if (!http_result_ || !https_result_)
    return;  // Still waiting for the other probe.

  metrics_->NotifyPortalDetectionMultiProbeResult(*http_result_,
                                                  *https_result_);
  CompleteTrial(*http_result_);
}

// Classifies a completed HTTP probe by status code: NoContent -> success,
// Redirect -> redirect, anything else -> failure (all in the content phase).
void PortalDetector::HttpRequestSuccessCallback(
    std::shared_ptr<brillo::http::Response> response) {
  // TODO(matthewmwang): check for 0 length data as well
  const int code = response->GetStatusCode();
  Status status = Status::kFailure;
  if (code == brillo::http::status_code::NoContent)
    status = Status::kSuccess;
  else if (code == brillo::http::status_code::Redirect)
    status = Status::kRedirect;
  http_result_ = std::make_unique<Result>(Phase::kContent, status);
  CompleteRequest();
}

// Classifies a completed HTTPS probe.  Only NoContent counts as success; any
// other status code is recorded as a content-phase failure (inconclusive).
void PortalDetector::HttpsRequestSuccessCallback(
    std::shared_ptr<brillo::http::Response> response) {
  const int code = response->GetStatusCode();
  const bool got_no_content = code == brillo::http::status_code::NoContent;
  https_result_ = std::make_unique<Result>(
      Phase::kContent, got_no_content ? Status::kSuccess : Status::kFailure);
  if (got_no_content) {
    // HTTPS probe success, probably no portal.
    LOG(INFO) << connection_->interface_name()
              << " HTTPS probe succeeded, probably no portal, attempt count=="
              << attempt_count_;
  } else {
    // HTTPS probe didn't get 204, inconclusive.
    LOG(ERROR) << connection_->interface_name()
               << " HTTPS probe returned with status code " << code
               << ". Portal detection inconclusive, attempt count=="
               << attempt_count_;
  }
  CompleteRequest();
}

// Records the outcome delivered through the HTTP probe's error callback,
// mapped to a portal phase/status.
void PortalDetector::HttpRequestErrorCallback(HttpRequest::Result result) {
  const Result mapped = GetPortalResultForRequestResult(result);
  http_result_ = std::make_unique<Result>(mapped);
  CompleteRequest();
}

// Records and logs the outcome delivered through the HTTPS probe's error
// callback, mapped to a portal phase/status.
void PortalDetector::HttpsRequestErrorCallback(HttpRequest::Result result) {
  https_result_ =
      std::make_unique<Result>(GetPortalResultForRequestResult(result));
  const Result& https = *https_result_;
  LOG(INFO) << connection_->interface_name()
            << " HTTPS probe failed with phase=="
            << PortalDetector::PhaseToString(https.phase) << ", status=="
            << PortalDetector::StatusToString(https.status);
  CompleteRequest();
}

// Reports the same |is_active_| flag that IsActive() reads.
bool PortalDetector::IsInProgress() { return is_active_; }

// Finalizes an attempt.  Stamps the attempt count onto |trial_result|, logs
// the outcome and resets per-trial state.  The owner's callback runs last,
// after all cleanup.
void PortalDetector::CompleteAttempt(PortalDetector::Result trial_result) {
  trial_result.num_attempts = attempt_count_;
  LOG(INFO) << connection_->interface_name()
            << StringPrintf(" Portal detection completed attempt %d with "
                            "phase==%s, status==%s",
                            attempt_count_,
                            PhaseToString(trial_result.phase).c_str(),
                            StatusToString(trial_result.status).c_str());
  CleanupTrial();
  portal_result_callback_.Run(trial_result);
}

// Sets |attempt_start_time_| to the current monotonic time plus
// |delay_seconds|, i.e. the moment the scheduled attempt is due to begin.
void PortalDetector::UpdateAttemptTime(int delay_seconds) {
  const struct timeval delay = {delay_seconds, 0};
  time_->GetTimeMonotonic(&attempt_start_time_);
  timeradd(&attempt_start_time_, &delay, &attempt_start_time_);
}

// Returns the part of |init_delay_seconds| that has not yet elapsed since
// |attempt_start_time_|, or 0 if at least that much time has already passed.
// Must only be called after at least one attempt; otherwise it LOG(FATAL)s.
int PortalDetector::AdjustStartDelay(int init_delay_seconds) {
  int adjusted_delay_seconds = 0;
  if (attempt_count_ <= 0) {
    LOG(FATAL) << "AdjustStartDelay in PortalDetector called without "
                  "previous attempts";
  } else {
    struct timeval now;
    struct timeval since_attempt_start;
    time_->GetTimeMonotonic(&now);
    timersub(&now, &attempt_start_time_, &since_attempt_start);
    SLOG(connection_.get(), 4) << "Elapsed time from previous attempt is "
                               << since_attempt_start.tv_sec << " seconds.";
    if (since_attempt_start.tv_sec < init_delay_seconds)
      adjusted_delay_seconds = init_delay_seconds - since_attempt_start.tv_sec;
  }
  SLOG(connection_.get(), 3)
      << "Adjusting trial start delay from " << init_delay_seconds
      << " seconds to " << adjusted_delay_seconds << " seconds.";
  return adjusted_delay_seconds;
}

// Maps |phase| to its D-Bus service-constant name.  kUnknown and any value
// without a dedicated case map to kPortalDetectionPhaseUnknown.
// static
const string PortalDetector::PhaseToString(Phase phase) {
  const char* name = kPortalDetectionPhaseUnknown;
  switch (phase) {
    case Phase::kConnection:
      name = kPortalDetectionPhaseConnection;
      break;
    case Phase::kDNS:
      name = kPortalDetectionPhaseDns;
      break;
    case Phase::kHTTP:
      name = kPortalDetectionPhaseHttp;
      break;
    case Phase::kContent:
      name = kPortalDetectionPhaseContent;
      break;
    case Phase::kUnknown:
    default:
      break;
  }
  return name;
}

// Maps |status| to its D-Bus service-constant name.  Any value without a
// dedicated case maps to kPortalDetectionStatusFailure.
// NOTE(review): Status::kRedirect (produced by HttpRequestSuccessCallback())
// has no case here, so it is reported as "failure"; confirm this is intended.
// static
const string PortalDetector::StatusToString(Status status) {
  const char* name = kPortalDetectionStatusFailure;
  switch (status) {
    case Status::kSuccess:
      name = kPortalDetectionStatusSuccess;
      break;
    case Status::kTimeout:
      name = kPortalDetectionStatusTimeout;
      break;
    case Status::kFailure:
    default:
      break;
  }
  return name;
}

// Translates an HttpRequest outcome into the portal phase where it failed and
// whether it was a timeout.  Every outcome maps to a non-success status.
// Unrecognized outcomes map to (kUnknown, kFailure).
PortalDetector::Result PortalDetector::GetPortalResultForRequestResult(
    HttpRequest::Result result) {
  Phase phase = Phase::kUnknown;
  Status status = Status::kFailure;
  switch (result) {
    case HttpRequest::kResultSuccess:
      // The request completed without receiving the expected payload.
      phase = Phase::kContent;
      break;
    case HttpRequest::kResultDNSFailure:
      phase = Phase::kDNS;
      break;
    case HttpRequest::kResultDNSTimeout:
      phase = Phase::kDNS;
      status = Status::kTimeout;
      break;
    case HttpRequest::kResultConnectionFailure:
      phase = Phase::kConnection;
      break;
    case HttpRequest::kResultHTTPFailure:
      phase = Phase::kHTTP;
      break;
    case HttpRequest::kResultHTTPTimeout:
      phase = Phase::kHTTP;
      status = Status::kTimeout;
      break;
    case HttpRequest::kResultInvalidInput:
    case HttpRequest::kResultUnknown:
    default:
      break;
  }
  return Result(phase, status);
}

}  // namespace shill
