// Copyright 2014 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "base/bind.h"
#include "base/location.h"
#include "base/run_loop.h"
#include "base/single_thread_task_runner.h"
#include "base/threading/thread_task_runner_handle.h"
#include "chrome/browser/extensions/extension_install_checker.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace extensions {

namespace {

// Blacklist state the stubbed blacklist check reports when SetAllErrors() is
// used to simulate a failure.
constexpr BlacklistState kBlacklistStateError = BLACKLISTED_MALWARE;

// Canned error strings reported by the stubbed requirements and policy checks.
constexpr char kDummyRequirementsError[] = "Requirements error";
constexpr char kDummyPolicyError[] = "Cannot install extension";

}  // namespace

// Stubs most of the checks since we are interested in validating the logic in
// the install checker. This class implements a synchronous version of all
// checks.
class ExtensionInstallCheckerForTest : public ExtensionInstallChecker {
 public:
  ExtensionInstallCheckerForTest(int enabled_checks, bool fail_fast)
      : ExtensionInstallChecker(nullptr, nullptr, enabled_checks, fail_fast) {}

  ~ExtensionInstallCheckerForTest() override {}

  // Sets the error the stubbed requirements check reports. An empty string
  // makes the check report no errors.
  void set_requirements_error(const std::string& error) {
    requirements_error_ = error;
  }
  // Sets the error the stubbed policy check reports. An empty string makes the
  // check report success.
  void set_policy_check_error(const std::string& error) {
    policy_check_error_ = error;
  }
  // Sets the state the stubbed blacklist check reports.
  void set_blacklist_state(BlacklistState state) { blacklist_state_ = state; }

  // Whether the corresponding Check*() override has been invoked.
  bool requirements_check_called() const { return requirements_check_called_; }
  bool blacklist_check_called() const { return blacklist_check_called_; }
  bool policy_check_called() const { return policy_check_called_; }

  // Reports the configured requirements result, but only while the checker
  // is still running.
  void MockCheckRequirements() {
    if (is_running()) {
      std::vector<std::string> reported_errors;
      if (!requirements_error_.empty())
        reported_errors.push_back(requirements_error_);
      OnRequirementsCheckDone(reported_errors);
    }
  }

  // Reports the configured blacklist state, but only while the checker is
  // still running.
  void MockCheckBlacklistState() {
    if (is_running())
      OnBlacklistStateCheckDone(blacklist_state_);
  }

 protected:
  void CheckRequirements() override {
    requirements_check_called_ = true;
    MockCheckRequirements();
  }

  void CheckManagementPolicy() override {
    policy_check_called_ = true;
    const bool policy_passes = policy_check_error_.empty();
    OnManagementPolicyCheckDone(policy_passes, policy_check_error_);
  }

  void CheckBlacklistState() override {
    blacklist_check_called_ = true;
    MockCheckBlacklistState();
  }

  // Which checks have been started.
  bool requirements_check_called_ = false;
  bool blacklist_check_called_ = false;
  bool policy_check_called_ = false;

  // Canned results reported by the stubbed checks.
  std::string requirements_error_;
  std::string policy_check_error_;
  BlacklistState blacklist_state_ = NOT_BLACKLISTED;
};

// This class implements asynchronous mocks of the requirements and blacklist
// checks.
class ExtensionInstallCheckerAsync : public ExtensionInstallCheckerForTest {
 public:
  ExtensionInstallCheckerAsync(int enabled_checks, bool fail_fast)
      : ExtensionInstallCheckerForTest(enabled_checks, fail_fast) {}

 protected:
  // Marks the blacklist check as started, then delivers its result from a
  // task posted to the current thread instead of synchronously.
  // base::Unretained(this): the checker must outlive the posted task.
  void CheckBlacklistState() override {
    blacklist_check_called_ = true;
    base::ThreadTaskRunnerHandle::Get()->PostTask(
        FROM_HERE,
        base::Bind(&ExtensionInstallCheckerForTest::MockCheckBlacklistState,
                   base::Unretained(this)));
  }

  // Marks the requirements check as started, then delivers its result from a
  // task posted to the current thread instead of synchronously.
  // base::Unretained(this): the checker must outlive the posted task.
  void CheckRequirements() override {
    requirements_check_called_ = true;
    base::ThreadTaskRunnerHandle::Get()->PostTask(
        FROM_HERE,
        base::Bind(&ExtensionInstallCheckerForTest::MockCheckRequirements,
                   base::Unretained(this)));
  }
};

// Captures the arguments of the checker's completion callback and how many
// times it was invoked.
class CheckObserver {
 public:
  CheckObserver() = default;

  int result() const { return result_; }
  int call_count() const { return call_count_; }

  // Completion callback: remembers the failed-checks value and counts calls.
  void OnChecksComplete(int checks_failed) {
    ++call_count_;
    result_ = checks_failed;
  }

  // If the callback has not fired yet, runs the current thread's pending
  // tasks until idle so that asynchronous checks can complete.
  void Wait() {
    if (call_count_ == 0)
      base::RunLoop().RunUntilIdle();
  }

 private:
  int result_ = 0;
  int call_count_ = 0;
};

class ExtensionInstallCheckerTest : public testing::Test {
 public:
  ExtensionInstallCheckerTest() {}
  ~ExtensionInstallCheckerTest() override {}

 protected:
  void SetAllErrors(ExtensionInstallCheckerForTest* checker) {
    checker->set_blacklist_state(kBlacklistStateError);
    checker->set_policy_check_error(kDummyPolicyError);
    checker->set_requirements_error(kDummyRequirementsError);
  }

  void ValidateExpectedCalls(int call_mask,
                             const ExtensionInstallCheckerForTest& checker) {
    bool expect_blacklist_checked =
        (call_mask & ExtensionInstallChecker::CHECK_BLACKLIST) != 0;
    bool expect_requirements_checked =
        (call_mask & ExtensionInstallChecker::CHECK_REQUIREMENTS) != 0;
    bool expect_policy_checked =
        (call_mask & ExtensionInstallChecker::CHECK_MANAGEMENT_POLICY) != 0;
    EXPECT_EQ(expect_blacklist_checked, checker.blacklist_check_called());
    EXPECT_EQ(expect_policy_checked, checker.policy_check_called());
    EXPECT_EQ(expect_requirements_checked, checker.requirements_check_called());
  }

  void ExpectRequirementsPass(const ExtensionInstallCheckerForTest& checker) {
    EXPECT_TRUE(checker.requirement_errors().empty());
  }

  void ExpectRequirementsError(const char* expected_error,
                               const ExtensionInstallCheckerForTest& checker) {
    EXPECT_FALSE(checker.requirement_errors().empty());
    EXPECT_EQ(std::string(expected_error),
              checker.requirement_errors().front());
  }

  void ExpectRequirementsError(const ExtensionInstallCheckerForTest& checker) {
    ExpectRequirementsError(kDummyRequirementsError, checker);
  }

  void ExpectBlacklistPass(const ExtensionInstallCheckerForTest& checker) {
    EXPECT_EQ(NOT_BLACKLISTED, checker.blacklist_state());
  }

  void ExpectBlacklistError(const ExtensionInstallCheckerForTest& checker) {
    EXPECT_EQ(kBlacklistStateError, checker.blacklist_state());
  }

  void ExpectPolicyPass(const ExtensionInstallCheckerForTest& checker) {
    EXPECT_TRUE(checker.policy_allows_load());
    EXPECT_TRUE(checker.policy_error().empty());
  }

  void ExpectPolicyError(const char* expected_error,
                         const ExtensionInstallCheckerForTest& checker) {
    EXPECT_FALSE(checker.policy_allows_load());
    EXPECT_FALSE(checker.policy_error().empty());
    EXPECT_EQ(std::string(expected_error), checker.policy_error());
  }

  void ExpectPolicyError(const ExtensionInstallCheckerForTest& checker) {
    ExpectPolicyError(kDummyPolicyError, checker);
  }

  void RunChecker(ExtensionInstallCheckerForTest* checker,
                  int expected_checks_run,
                  int expected_result) {
    CheckObserver observer;
    checker->Start(base::Bind(&CheckObserver::OnChecksComplete,
                              base::Unretained(&observer)));
    observer.Wait();

    EXPECT_FALSE(checker->is_running());
    EXPECT_EQ(expected_result, observer.result());
    EXPECT_EQ(1, observer.call_count());
    ValidateExpectedCalls(expected_checks_run, *checker);
  }

  void DoRunAllChecksPass(ExtensionInstallCheckerForTest* checker) {
    RunChecker(checker,
               ExtensionInstallChecker::CHECK_ALL,
               0);

    ExpectRequirementsPass(*checker);
    ExpectPolicyPass(*checker);
    ExpectBlacklistPass(*checker);
  }

  void DoRunAllChecksFail(ExtensionInstallCheckerForTest* checker) {
    SetAllErrors(checker);
    RunChecker(checker,
               ExtensionInstallChecker::CHECK_ALL,
               ExtensionInstallChecker::CHECK_ALL);

    ExpectRequirementsError(*checker);
    ExpectPolicyError(*checker);
    ExpectBlacklistError(*checker);
  }

  void DoRunSubsetOfChecks(int checks_to_run) {
    ExtensionInstallCheckerForTest sync_checker(checks_to_run,
                                                /*fail_fast=*/false);
    ExtensionInstallCheckerAsync async_checker(checks_to_run,
                                               /*fail_fast=*/false);
    ExtensionInstallCheckerForTest* checkers[] = {
        &sync_checker, &async_checker,
    };

    for (auto* checker : checkers) {
      SetAllErrors(checker);
      RunChecker(checker, checks_to_run, checks_to_run);

      if (checks_to_run & ExtensionInstallChecker::CHECK_REQUIREMENTS)
        ExpectRequirementsError(*checker);
      else
        ExpectRequirementsPass(*checker);

      if (checks_to_run & ExtensionInstallChecker::CHECK_MANAGEMENT_POLICY)
        ExpectPolicyError(*checker);
      else
        ExpectPolicyPass(*checker);

      if (checks_to_run & ExtensionInstallChecker::CHECK_BLACKLIST)
        ExpectBlacklistError(*checker);
      else
        ExpectBlacklistPass(*checker);
    }
  }

 private:
  // A message loop is required for the asynchronous tests.
  base::MessageLoop message_loop;
};

// Test the case where all checks pass.
TEST_F(ExtensionInstallCheckerTest, AllSucceeded) {
  // Synchronous stubs.
  {
    ExtensionInstallCheckerForTest checker(ExtensionInstallChecker::CHECK_ALL,
                                           /*fail_fast=*/false);
    DoRunAllChecksPass(&checker);
  }

  // Asynchronous stubs.
  {
    ExtensionInstallCheckerAsync checker(ExtensionInstallChecker::CHECK_ALL,
                                         /*fail_fast=*/false);
    DoRunAllChecksPass(&checker);
  }
}

// Test the case where all checks fail.
TEST_F(ExtensionInstallCheckerTest, AllFailed) {
  // Synchronous stubs.
  {
    ExtensionInstallCheckerForTest checker(ExtensionInstallChecker::CHECK_ALL,
                                           /*fail_fast=*/false);
    DoRunAllChecksFail(&checker);
  }

  // Asynchronous stubs.
  {
    ExtensionInstallCheckerAsync checker(ExtensionInstallChecker::CHECK_ALL,
                                         /*fail_fast=*/false);
    DoRunAllChecksFail(&checker);
  }
}

// Test running only a subset of the checks.
TEST_F(ExtensionInstallCheckerTest, RunSubsetOfChecks) {
  // Each entry is a bitmask of the checks to enable for one run.
  const int kCheckSubsets[] = {
      ExtensionInstallChecker::CHECK_MANAGEMENT_POLICY |
          ExtensionInstallChecker::CHECK_REQUIREMENTS,
      ExtensionInstallChecker::CHECK_BLACKLIST |
          ExtensionInstallChecker::CHECK_REQUIREMENTS,
      ExtensionInstallChecker::CHECK_BLACKLIST,
  };

  for (int enabled_checks : kCheckSubsets)
    DoRunSubsetOfChecks(enabled_checks);
}

// Test fail fast with synchronous callbacks.
TEST_F(ExtensionInstallCheckerTest, FailFastSync) {
  // Relies on an implementation detail: the policy check runs first. With every
  // stub failing, fail-fast therefore stops after the policy check.
  {
    ExtensionInstallCheckerForTest all_checks_checker(
        ExtensionInstallChecker::CHECK_ALL, /*fail_fast=*/true);
    SetAllErrors(&all_checks_checker);

    const int kPolicyOnly = ExtensionInstallChecker::CHECK_MANAGEMENT_POLICY;
    RunChecker(&all_checks_checker, kPolicyOnly, kPolicyOnly);

    ExpectRequirementsPass(all_checks_checker);
    ExpectPolicyError(all_checks_checker);
    ExpectBlacklistPass(all_checks_checker);
  }

  // With the policy check disabled, the requirements check is the only one
  // started. Its failure prevents the blacklist check from starting.
  {
    ExtensionInstallCheckerForTest no_policy_checker(
        ExtensionInstallChecker::CHECK_REQUIREMENTS |
            ExtensionInstallChecker::CHECK_BLACKLIST,
        /*fail_fast=*/true);
    SetAllErrors(&no_policy_checker);

    const int kRequirementsOnly = ExtensionInstallChecker::CHECK_REQUIREMENTS;
    RunChecker(&no_policy_checker, kRequirementsOnly, kRequirementsOnly);

    ExpectRequirementsError(no_policy_checker);
    ExpectPolicyPass(no_policy_checker);
    ExpectBlacklistPass(no_policy_checker);
  }
}

// Test fail fast with asynchronous callbacks.
TEST_F(ExtensionInstallCheckerTest, FailFastAsync) {
  // Relies on an implementation detail: the requirements check is started
  // before the blacklist check. Both asynchronous checks get started, but the
  // requirements result is delivered first and the blacklist result is then
  // discarded.
  ExtensionInstallCheckerAsync async_checker(ExtensionInstallChecker::CHECK_ALL,
                                             /*fail_fast=*/true);
  SetAllErrors(&async_checker);

  // The policy stub completes synchronously. It has to pass, otherwise
  // fail-fast would end the run before the asynchronous checks are started.
  async_checker.set_policy_check_error(std::string());

  RunChecker(&async_checker, ExtensionInstallChecker::CHECK_ALL,
             ExtensionInstallChecker::CHECK_REQUIREMENTS);

  ExpectRequirementsError(async_checker);
  ExpectPolicyPass(async_checker);
  ExpectBlacklistPass(async_checker);
}

}  // namespace extensions
