// Copyright 2021 The Chromium Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "chrome/updater/activity_impl.h"

#include <string>
#include <tuple>

#include "base/functional/function_ref.h"
#include "base/strings/strcat.h"
#include "base/win/registry.h"
#include "base/win/windows_types.h"
#include "chrome/updater/test_scope.h"
#include "chrome/updater/updater_scope.h"
#include "chrome/updater/util/win_util.h"
#include "chrome/updater/win/user_info.h"
#include "chrome/updater/win/win_constants.h"
#include "testing/gtest/include/gtest/gtest.h"

namespace updater {
namespace {

struct ReadActiveBitRetval {
  ReadActiveBitRetval() = default;
  ReadActiveBitRetval(DWORD result, bool active_bit)
      : read_result(result), active_bit_set(active_bit) {}
  ReadActiveBitRetval(const ReadActiveBitRetval&) = default;
  ReadActiveBitRetval& operator=(const ReadActiveBitRetval&) = default;

  DWORD read_result = ERROR_FILE_NOT_FOUND;
  bool active_bit_set = false;
};

using ReadActiveBitCallback =
    base::FunctionRef<ReadActiveBitRetval(base::win::RegKey&)>;
using WriteActiveBitCallback =
    base::FunctionRef<DWORD(base::win::RegKey&, bool)>;

struct ReadWriteCallbacks {
  ReadWriteCallbacks() = delete;
  ReadWriteCallbacks(ReadActiveBitCallback read, WriteActiveBitCallback write)
      : read_callback(read), write_callback(write) {}
  ReadWriteCallbacks(const ReadWriteCallbacks&) = default;
  ReadWriteCallbacks& operator=(const ReadWriteCallbacks&) = delete;

  ReadActiveBitCallback read_callback;
  WriteActiveBitCallback write_callback;
};

constexpr wchar_t kDidRun[] = L"dr";
constexpr char kAppId[] = "{6ACB7D4D-E5BA-48b0-85FE-A4051500A1BD}";
constexpr wchar_t kClientStateKeyPath[] =
    CLIENT_STATE_KEY L"{6ACB7D4D-E5BA-48b0-85FE-A4051500A1BD}";

DWORD WriteActiveBitAsString(base::win::RegKey& key, bool value) {
  return key.WriteValue(kDidRun, value ? L"1" : L"0");
}

DWORD WriteActiveBitAsDword(base::win::RegKey& key, bool value) {
  return key.WriteValue(kDidRun, value);
}

ReadActiveBitRetval ReadActiveBitAsString(base::win::RegKey& key) {
  std::wstring did_run_str(L"0");
  const DWORD result = key.ReadValue(kDidRun, &did_run_str);
  return ReadActiveBitRetval(result, did_run_str == L"1");
}

ReadActiveBitRetval ReadActiveBitAsDword(base::win::RegKey& key) {
  DWORD did_run = 0;
  const DWORD result = key.ReadValueDW(kDidRun, &did_run);
  return ReadActiveBitRetval(result, did_run);
}

}  // namespace

class ActivityWinTest : public ::testing::TestWithParam<
                            std::tuple<bool, bool, ReadWriteCallbacks>> {
 protected:
  void SetUp() override {
    std::wstring sid;
    ASSERT_HRESULT_SUCCEEDED(GetProcessUser(nullptr, nullptr, &sid));
    low_integrity_key_path_ =
        base::StrCat({USER_REG_VISTA_LOW_INTEGRITY_HKCU, L"\\", sid, L"\\",
                      kClientStateKeyPath});
    TearDown();

    CreateUserActiveBit(SetUserValue(), WriteActiveBitFn());
    CreateLowIntegrityUserActiveBit(SetLowUserValue(), WriteActiveBitFn());
  }

  void TearDown() override {
    base::win::RegKey(HKEY_CURRENT_USER, L"", Wow6432(DELETE))
        .DeleteKey(kClientStateKeyPath);
    base::win::RegKey(HKEY_CURRENT_USER, L"", Wow6432(DELETE))
        .DeleteKey(low_integrity_key_path_.c_str());
  }

  UpdaterScope GetScope() const { return GetTestScope(); }

  bool SetUserValue() const { return std::get<0>(GetParam()); }

  bool SetLowUserValue() const { return std::get<1>(GetParam()); }

  WriteActiveBitCallback WriteActiveBitFn() const {
    return std::get<2>(GetParam()).write_callback;
  }

  ReadActiveBitCallback ReadActiveBitFn() const {
    return std::get<2>(GetParam()).read_callback;
  }

  void CreateActiveBit(const std::wstring& key_name,
                       bool value,
                       WriteActiveBitCallback callback) const {
    base::win::RegKey key;
    ASSERT_EQ(ERROR_SUCCESS, key.Create(HKEY_CURRENT_USER, key_name.c_str(),
                                        Wow6432(KEY_SET_VALUE)));
    ASSERT_EQ(DWORD{ERROR_SUCCESS}, callback(key, value));
  }

  void DeleteActiveBit(const std::wstring& key_name) const {
    base::win::RegKey key;
    ASSERT_EQ(ERROR_SUCCESS, key.Open(HKEY_CURRENT_USER, key_name.c_str(),
                                      Wow6432(KEY_SET_VALUE)));
    ASSERT_EQ(ERROR_SUCCESS, key.DeleteValue(kDidRun));
  }

  void CheckActiveBit(const std::wstring& key_name,
                      bool expected,
                      ReadActiveBitCallback callback) const {
    base::win::RegKey key;
    ASSERT_EQ(ERROR_SUCCESS, key.Open(HKEY_CURRENT_USER, key_name.c_str(),
                                      Wow6432(KEY_QUERY_VALUE)));

    const ReadActiveBitRetval retval = callback(key);
    ASSERT_EQ(DWORD{ERROR_SUCCESS}, retval.read_result);
    ASSERT_EQ(expected, retval.active_bit_set);
  }

  void CreateUserActiveBit(bool value, WriteActiveBitCallback callback) const {
    CreateActiveBit(kClientStateKeyPath, value, callback);
  }

  void DeleteUserActiveBit() const { DeleteActiveBit(kClientStateKeyPath); }

  void CheckUserActiveBit(bool expected, ReadActiveBitCallback callback) const {
    CheckActiveBit(kClientStateKeyPath, expected, callback);
  }

  void CreateLowIntegrityUserActiveBit(bool value,
                                       WriteActiveBitCallback callback) const {
    CreateActiveBit(low_integrity_key_path_, value, callback);
  }

  void DeleteLowIntegrityUserActiveBit() const {
    DeleteActiveBit(low_integrity_key_path_);
  }

  void CheckLowIntegrityUserActiveBit(bool expected,
                                      ReadActiveBitCallback callback) const {
    CheckActiveBit(low_integrity_key_path_, expected, callback);
  }

 private:
  std::wstring low_integrity_key_path_;
};

TEST_P(ActivityWinTest, GetActiveBit) {
  ASSERT_EQ(SetUserValue() || SetLowUserValue(),
            GetActiveBit(GetScope(), kAppId));

  CheckUserActiveBit(SetUserValue(), ReadActiveBitFn());
  CheckLowIntegrityUserActiveBit(SetLowUserValue(), ReadActiveBitFn());
}

TEST_P(ActivityWinTest, ClearActiveBit) {
  ClearActiveBit(GetScope(), kAppId);

  CheckUserActiveBit(false, &ReadActiveBitAsString);
  CheckLowIntegrityUserActiveBit(false, &ReadActiveBitAsString);
  ASSERT_FALSE(GetActiveBit(GetScope(), kAppId));
}

INSTANTIATE_TEST_SUITE_P(
    UpdaterScopeSetUserValueSetLowUserValueSetAsDword,
    ActivityWinTest,
    ::testing::Combine(
        ::testing::Bool(),
        ::testing::Bool(),
        ::testing::Values(ReadWriteCallbacks(&ReadActiveBitAsString,
                                             &WriteActiveBitAsString),
                          ReadWriteCallbacks(&ReadActiveBitAsDword,
                                             &WriteActiveBitAsDword))));

}  // namespace updater
