blob: 577bfbe2b53c0994d8ba84347850ce27d142da02 [file] [log] [blame] [edit]
// Copyright 2024 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "modemfwd/flash_task.h"
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <base/run_loop.h>
#include <base/test/task_environment.h>
#include <base/time/time.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "modemfwd/mock_daemon_delegate.h"
#include "modemfwd/mock_journal.h"
#include "modemfwd/mock_metrics.h"
#include "modemfwd/mock_modem.h"
#include "modemfwd/mock_modem_flasher.h"
#include "modemfwd/mock_notification_manager.h"
using testing::_;
using testing::Return;
namespace modemfwd {
namespace {
constexpr char kDeviceId1[] = "device:id:1";
constexpr char kEquipmentId1[] = "equipment_id_1";
constexpr char kMainFirmware1Version[] = "versionA";
constexpr char kMainFirmware2Path[] = "main_fw_2.fls";
constexpr char kMainFirmware2Version[] = "versionB";
constexpr char kOemFirmware1Version[] = "6000.1";
constexpr char kCarrier1[] = "uuid_1";
constexpr char kCarrier2[] = "uuid_2";
constexpr char kCarrier2Firmware1Path[] = "carrier_2_fw_1.fls";
constexpr char kCarrier2Firmware1Version[] = "4500.15.65";
// Journal entry ID
constexpr char kJournalEntryId[] = "journal-entry";
constexpr char kFirmwareDir[] = "/fake/firmware/dir";
} // namespace
class FlashTaskTest : public ::testing::Test {
public:
FlashTaskTest() {
delegate_ = std::make_unique<MockDelegate>();
journal_ = std::make_unique<MockJournal>();
notification_mgr_ = std::make_unique<MockNotificationManager>();
metrics_ = std::make_unique<MockMetrics>();
modem_flasher_ = std::make_unique<MockModemFlasher>();
}
std::unique_ptr<FlashTask> CreateFlashTask() {
return std::make_unique<FlashTask>(delegate_.get(), journal_.get(),
notification_mgr_.get(), metrics_.get(),
modem_flasher_.get());
}
protected:
std::unique_ptr<MockModem> GetDefaultModem() {
auto modem = std::make_unique<MockModem>();
ON_CALL(*modem, GetDeviceId()).WillByDefault(Return(kDeviceId1));
ON_CALL(*modem, GetEquipmentId()).WillByDefault(Return(kEquipmentId1));
ON_CALL(*modem, GetCarrierId()).WillByDefault(Return(kCarrier1));
ON_CALL(*modem, GetMainFirmwareVersion())
.WillByDefault(Return(kMainFirmware1Version));
ON_CALL(*modem, GetOemFirmwareVersion())
.WillByDefault(Return(kOemFirmware1Version));
ON_CALL(*modem, GetCarrierFirmwareId()).WillByDefault(Return(""));
ON_CALL(*modem, GetCarrierFirmwareVersion()).WillByDefault(Return(""));
return modem;
}
void SetCarrierFirmwareInfo(MockModem* modem,
const std::string& carrier_id,
const std::string& version) {
ON_CALL(*modem, GetCarrierFirmwareId()).WillByDefault(Return(carrier_id));
ON_CALL(*modem, GetCarrierFirmwareVersion()).WillByDefault(Return(version));
}
std::unique_ptr<FlashConfig> GetConfig(const std::string& carrier_id,
std::vector<FirmwareConfig> fw_cfgs) {
std::map<std::string, std::unique_ptr<FirmwareFile>> files;
for (const auto& fw_cfg : fw_cfgs) {
auto file = std::make_unique<FirmwareFile>();
file->PrepareFrom(base::FilePath(kFirmwareDir),
FirmwareFileInfo(fw_cfg.path.value(), fw_cfg.version));
files[fw_cfg.fw_type] = std::move(file);
}
return std::make_unique<FlashConfig>(carrier_id, std::move(fw_cfgs),
std::move(files));
}
std::unique_ptr<MockDelegate> delegate_;
std::unique_ptr<MockJournal> journal_;
std::unique_ptr<MockNotificationManager> notification_mgr_;
std::unique_ptr<MockMetrics> metrics_;
std::unique_ptr<MockModemFlasher> modem_flasher_;
};
TEST_F(FlashTaskTest, ModemIsBlocked) {
auto task = CreateFlashTask();
auto modem = GetDefaultModem();
EXPECT_CALL(*modem_flasher_, ShouldFlash(modem.get(), _))
.WillOnce(Return(false));
brillo::ErrorPtr err;
EXPECT_FALSE(task->Start(modem.get(), FlashTask::Options{}, &err));
}
TEST_F(FlashTaskTest, NothingToFlash) {
auto task = CreateFlashTask();
auto modem = GetDefaultModem();
EXPECT_CALL(*modem_flasher_, ShouldFlash(modem.get(), _))
.WillOnce(Return(true));
EXPECT_CALL(*modem_flasher_, BuildFlashConfig(modem.get(), _, _))
.WillOnce(Return(GetConfig(kCarrier1, {})));
EXPECT_CALL(*modem_flasher_, RunFlash(modem.get(), _, _, _)).Times(0);
brillo::ErrorPtr err;
EXPECT_TRUE(task->Start(modem.get(), FlashTask::Options{}, &err));
}
TEST_F(FlashTaskTest, BuildConfigReturnedError) {
auto task = CreateFlashTask();
auto modem = GetDefaultModem();
EXPECT_CALL(*modem_flasher_, ShouldFlash(modem.get(), _))
.WillOnce(Return(true));
EXPECT_CALL(*modem_flasher_, BuildFlashConfig(modem.get(), _, _))
.WillOnce(Return(std::unique_ptr<FlashConfig>()));
EXPECT_CALL(*modem_flasher_, RunFlash(modem.get(), _, _, _)).Times(0);
brillo::ErrorPtr err;
EXPECT_FALSE(task->Start(modem.get(), FlashTask::Options{}, &err));
}
TEST_F(FlashTaskTest, FlashFailure) {
auto task = CreateFlashTask();
auto modem = GetDefaultModem();
base::FilePath new_firmware(kMainFirmware2Path);
EXPECT_CALL(*modem_flasher_, ShouldFlash(modem.get(), _))
.WillOnce(Return(true));
EXPECT_CALL(*modem_flasher_, BuildFlashConfig(modem.get(), _, _))
.WillOnce(Return(GetConfig(
kCarrier1, {{kFwMain, new_firmware, kMainFirmware2Version}})));
EXPECT_CALL(*modem_flasher_, RunFlash(modem.get(), _, _, _))
.WillOnce(Return(false));
brillo::ErrorPtr err;
EXPECT_FALSE(task->Start(modem.get(), FlashTask::Options{}, &err));
}
TEST_F(FlashTaskTest, FlashSuccess) {
auto task = CreateFlashTask();
auto modem = GetDefaultModem();
base::FilePath new_firmware(kMainFirmware2Path);
EXPECT_CALL(*modem_flasher_, ShouldFlash(modem.get(), _))
.WillOnce(Return(true));
EXPECT_CALL(*modem_flasher_, BuildFlashConfig(modem.get(), _, _))
.WillOnce(Return(GetConfig(
kCarrier1, {{kFwMain, new_firmware, kMainFirmware2Version}})));
EXPECT_CALL(*modem_flasher_, RunFlash(modem.get(), _, _, _))
.WillOnce(Return(true));
EXPECT_CALL(*metrics_, SendFwFlashTime(_)).Times(1);
brillo::ErrorPtr err;
EXPECT_TRUE(task->Start(modem.get(), FlashTask::Options{}, &err));
EXPECT_EQ(err.get(), nullptr);
}
TEST_F(FlashTaskTest, WritesToJournal) {
auto task = CreateFlashTask();
auto modem = GetDefaultModem();
base::FilePath new_firmware(kMainFirmware2Path);
EXPECT_CALL(*modem_flasher_, ShouldFlash(modem.get(), _))
.WillOnce(Return(true));
EXPECT_CALL(*modem_flasher_, BuildFlashConfig(modem.get(), _, _))
.WillOnce(Return(GetConfig(
kCarrier1, {{kFwMain, new_firmware, kMainFirmware2Version}})));
EXPECT_CALL(*modem_flasher_, RunFlash(modem.get(), _, _, _))
.WillOnce(Return(true));
EXPECT_CALL(*journal_, MarkStartOfFlashingFirmware(_, kDeviceId1, _))
.WillOnce(Return(kJournalEntryId));
// The cleanup callback marks the end of flashing the firmware. The delegate
// will typically run it once the modem comes back up.
base::OnceClosure cb;
EXPECT_CALL(*delegate_, RegisterOnModemReappearanceCallback(_, _))
.WillOnce([&cb](const std::string& /*equipment_id*/,
base::OnceClosure reg_cb) { cb = std::move(reg_cb); });
brillo::ErrorPtr err;
EXPECT_TRUE(task->Start(modem.get(), FlashTask::Options{}, &err));
EXPECT_EQ(err.get(), nullptr);
// Ensure the journal entry is closed by the registered callback.
EXPECT_CALL(*journal_, MarkEndOfFlashingFirmware(kJournalEntryId)).Times(1);
EXPECT_FALSE(cb.is_null());
std::move(cb).Run();
}
TEST_F(FlashTaskTest, WritesCarrierToJournal) {
auto task = CreateFlashTask();
auto modem = GetDefaultModem();
base::FilePath new_firmware(kCarrier2Firmware1Path);
EXPECT_CALL(*modem_flasher_, ShouldFlash(modem.get(), _))
.WillOnce(Return(true));
EXPECT_CALL(*modem_flasher_, BuildFlashConfig(modem.get(), _, _))
.WillOnce(Return(GetConfig(
kCarrier2, {{kFwCarrier, new_firmware, kCarrier2Firmware1Version}})));
EXPECT_CALL(*modem_flasher_, RunFlash(modem.get(), _, _, _))
.WillOnce(Return(true));
EXPECT_CALL(*journal_, MarkStartOfFlashingFirmware(_, kDeviceId1, kCarrier2))
.WillOnce(Return(kJournalEntryId));
brillo::ErrorPtr err;
EXPECT_TRUE(task->Start(modem.get(), FlashTask::Options{}, &err));
EXPECT_EQ(err.get(), nullptr);
}
TEST_F(FlashTaskTest, WritesToJournalOnFailure) {
auto task = CreateFlashTask();
auto modem = GetDefaultModem();
base::FilePath new_firmware(kMainFirmware2Path);
EXPECT_CALL(*modem_flasher_, ShouldFlash(modem.get(), _))
.WillOnce(Return(true));
EXPECT_CALL(*modem_flasher_, BuildFlashConfig(modem.get(), _, _))
.WillOnce(Return(GetConfig(
kCarrier1, {{kFwMain, new_firmware, kMainFirmware2Version}})));
EXPECT_CALL(*modem_flasher_, RunFlash(modem.get(), _, _, _))
.WillOnce(Return(false));
EXPECT_CALL(*journal_, MarkStartOfFlashingFirmware(_, kDeviceId1, _))
.WillOnce(Return(kJournalEntryId));
EXPECT_CALL(*journal_, MarkEndOfFlashingFirmware(kJournalEntryId)).Times(1);
// We should complete inline on failure. No callback should be registered.
EXPECT_CALL(*delegate_, RegisterOnModemReappearanceCallback(_, _)).Times(0);
brillo::ErrorPtr err;
EXPECT_FALSE(task->Start(modem.get(), FlashTask::Options{}, &err));
}
TEST_F(FlashTaskTest, InhibitDuringFlash) {
auto task = CreateFlashTask();
auto modem = GetDefaultModem();
base::FilePath new_firmware(kMainFirmware2Path);
EXPECT_CALL(*modem_flasher_, ShouldFlash(modem.get(), _))
.WillOnce(Return(true));
EXPECT_CALL(*modem_flasher_, BuildFlashConfig(modem.get(), _, _))
.WillOnce(Return(GetConfig(
kCarrier1, {{kFwMain, new_firmware, kMainFirmware2Version}})));
EXPECT_CALL(*modem_flasher_, RunFlash(modem.get(), _, _, _))
.WillOnce(Return(true));
EXPECT_CALL(*modem, SetInhibited(true)).Times(1);
EXPECT_CALL(*modem, SetInhibited(false)).Times(1);
brillo::ErrorPtr err;
EXPECT_TRUE(task->Start(modem.get(), FlashTask::Options{}, &err));
EXPECT_EQ(err.get(), nullptr);
}
TEST_F(FlashTaskTest, IgnoreBlock) {
auto task = CreateFlashTask();
auto modem = GetDefaultModem();
base::FilePath new_firmware(kMainFirmware2Path);
EXPECT_CALL(*modem_flasher_, ShouldFlash(modem.get(), _))
.WillRepeatedly(Return(false));
EXPECT_CALL(*modem_flasher_, BuildFlashConfig(modem.get(), _, _))
.WillOnce(Return(GetConfig(
kCarrier1, {{kFwMain, new_firmware, kMainFirmware2Version}})));
EXPECT_CALL(*modem_flasher_, RunFlash(modem.get(), _, _, _))
.WillOnce(Return(true));
brillo::ErrorPtr err;
EXPECT_TRUE(task->Start(
modem.get(), FlashTask::Options{.should_always_flash = true}, &err));
EXPECT_EQ(err.get(), nullptr);
}
} // namespace modemfwd