blob: 52bf40f4ce413ad079902f735783313e4abb5b69 [file] [edit]
// Copyright 2024 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "modemfwd/flash_task.h"
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <base/files/scoped_temp_dir.h>
#include <base/run_loop.h>
#include <base/test/task_environment.h>
#include <base/time/time.h>
#include <dbus/message.h>
#include <dbus/mock_bus.h>
#include <dbus/mock_object_proxy.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "modemfwd/error.h"
#include "modemfwd/mock_daemon_delegate.h"
#include "modemfwd/mock_journal.h"
#include "modemfwd/mock_metrics.h"
#include "modemfwd/mock_modem.h"
#include "modemfwd/mock_modem_flasher.h"
#include "modemfwd/mock_notification_manager.h"
using testing::_;
using testing::IsNull;
using testing::NotNull;
using testing::Return;
using testing::SetArgPointee;
using testing::WithArg;
namespace modemfwd {
namespace {
constexpr char kDeviceId1[] = "device:id:1";
constexpr char kEquipmentId1[] = "equipment_id_1";
constexpr char kMainFirmware1Version[] = "versionA";
constexpr char kMainFirmware2Path[] = "main_fw_2.fls";
constexpr char kMainFirmware2Version[] = "versionB";
constexpr char kOemFirmware1Version[] = "6000.1";
constexpr char kCarrier1[] = "uuid_1";
constexpr char kCarrier2[] = "uuid_2";
constexpr char kCarrier2Firmware1Path[] = "carrier_2_fw_1.fls";
constexpr char kCarrier2Firmware1Version[] = "4500.15.65";
// Journal entry ID
constexpr char kJournalEntryId[] = "journal-entry";
constexpr char kFirmwareDir[] = "/fake/firmware/dir";
} // namespace
class FlashTaskTest : public ::testing::Test {
public:
FlashTaskTest() {
delegate_ = std::make_unique<MockDelegate>();
journal_ = std::make_unique<MockJournal>();
notification_mgr_ = std::make_unique<MockNotificationManager>();
metrics_ = std::make_unique<MockMetrics>();
bus_ = base::MakeRefCounted<dbus::MockBus>(dbus::Bus::Options());
upstart_object_proxy_ = base::MakeRefCounted<dbus::MockObjectProxy>(
bus_.get(), UpstartJobController::kUpstartServiceName,
dbus::ObjectPath(UpstartJobController::kUpstartPath));
ON_CALL(*upstart_object_proxy_, CallMethodAndBlock(_, _))
.WillByDefault([](dbus::MethodCall* method_call, int /*timeout_ms*/) {
auto resp = dbus::Response::CreateEmpty();
dbus::MessageWriter writer(resp.get());
if (method_call->GetMember() == "GetAllJobs") {
writer.AppendArrayOfObjectPaths(std::vector<dbus::ObjectPath>{
dbus::ObjectPath(UpstartJobController::kHermesJobPath),
dbus::ObjectPath(UpstartJobController::kModemHelperJobPath)});
}
return resp;
});
ON_CALL(*bus_, GetObjectProxy(
UpstartJobController::kUpstartServiceName,
dbus::ObjectPath(UpstartJobController::kUpstartPath)))
.WillByDefault(Return(upstart_object_proxy_.get()));
// We can use the same mock object for all jobs.
job_object_proxy_ = base::MakeRefCounted<dbus::MockObjectProxy>(
bus_.get(), UpstartJobController::kUpstartServiceName,
dbus::ObjectPath("/job"));
ON_CALL(*job_object_proxy_, CallMethodAndBlock(_, _))
.WillByDefault([](dbus::MethodCall* method_call, int /*timeout_ms*/) {
auto resp = dbus::Response::CreateEmpty();
dbus::MessageWriter writer(resp.get());
if (method_call->GetMember() == "Start" ||
method_call->GetMember() == "GetInstance") {
writer.AppendObjectPath(dbus::ObjectPath("/job"));
}
return resp;
});
ON_CALL(*bus_, GetObjectProxy(UpstartJobController::kUpstartServiceName, _))
.WillByDefault(Return(job_object_proxy_.get()));
auto modem_flasher = std::make_unique<MockModemFlasher>();
modem_flasher_ = modem_flasher.get();
async_modem_flasher_ =
base::MakeRefCounted<AsyncModemFlasher>(std::move(modem_flasher));
}
std::unique_ptr<FlashTask> CreateFlashTask() {
return std::make_unique<FlashTask>(delegate_.get(), journal_.get(),
notification_mgr_.get(), metrics_.get(),
bus_, async_modem_flasher_);
}
protected:
scoped_refptr<MockModem> GetDefaultModem() {
auto modem = scoped_refptr<MockModem>(new MockModem);
ON_CALL(*modem, IsPresent()).WillByDefault(Return(true));
ON_CALL(*modem, GetDeviceId()).WillByDefault(Return(kDeviceId1));
ON_CALL(*modem, GetEquipmentId()).WillByDefault(Return(kEquipmentId1));
ON_CALL(*modem, GetCarrierId()).WillByDefault(Return(kCarrier1));
ON_CALL(*modem, GetMainFirmwareVersion())
.WillByDefault(Return(kMainFirmware1Version));
ON_CALL(*modem, GetOemFirmwareVersion())
.WillByDefault(Return(kOemFirmware1Version));
ON_CALL(*modem, GetCarrierFirmwareId()).WillByDefault(Return(""));
ON_CALL(*modem, GetCarrierFirmwareVersion()).WillByDefault(Return(""));
return modem;
}
void SetCarrierFirmwareInfo(MockModem* modem,
const std::string& carrier_id,
const std::string& version) {
ON_CALL(*modem, GetCarrierFirmwareId()).WillByDefault(Return(carrier_id));
ON_CALL(*modem, GetCarrierFirmwareVersion()).WillByDefault(Return(version));
}
std::unique_ptr<FlashConfig> GetConfig(const std::string& carrier_id,
std::vector<FirmwareConfig> fw_cfgs) {
std::map<std::string, std::unique_ptr<FirmwareFile>> files;
std::vector<std::unique_ptr<FirmwareFile>> recovery_files;
base::ScopedTempDir temp_extraction_dir;
EXPECT_TRUE(temp_extraction_dir.CreateUniqueTempDir());
for (const auto& fw_cfg : fw_cfgs) {
auto file = std::make_unique<FirmwareFile>();
file->PrepareFrom(base::FilePath(kFirmwareDir),
temp_extraction_dir.GetPath(),
FirmwareFileInfo(fw_cfg.path.value(), fw_cfg.version));
files[fw_cfg.fw_type] = std::move(file);
}
return std::make_unique<FlashConfig>(
carrier_id, std::move(fw_cfgs), std::move(files),
std::move(recovery_files), std::move(temp_extraction_dir));
}
brillo::ErrorPtr RunTask(FlashTask* task,
scoped_refptr<Modem> modem,
const FlashTask::Options& options) {
base::RunLoop run_loop;
brillo::ErrorPtr error;
EXPECT_CALL(*delegate_, FinishTask(task, _))
.WillOnce(WithArg<1>(
[&error, quit = run_loop.QuitClosure()](brillo::ErrorPtr e) {
error = std::move(e);
std::move(quit).Run();
}));
task->Start(modem, options);
run_loop.Run();
return error;
}
base::test::TaskEnvironment task_environment_;
std::unique_ptr<MockDelegate> delegate_;
std::unique_ptr<MockJournal> journal_;
std::unique_ptr<MockNotificationManager> notification_mgr_;
std::unique_ptr<MockMetrics> metrics_;
scoped_refptr<dbus::MockBus> bus_;
scoped_refptr<dbus::MockObjectProxy> upstart_object_proxy_;
scoped_refptr<dbus::MockObjectProxy> job_object_proxy_;
MockModemFlasher* modem_flasher_;
scoped_refptr<AsyncModemFlasher> async_modem_flasher_;
};
TEST_F(FlashTaskTest, ModemIsBlocked) {
auto task = CreateFlashTask();
auto modem = GetDefaultModem();
EXPECT_CALL(*modem_flasher_, ShouldFlash(modem.get(), _))
.WillOnce(WithArg<1>([](brillo::ErrorPtr* error) {
*error = Error::Create(FROM_HERE, "foo", "foo");
return false;
}));
EXPECT_CALL(*delegate_, NotifyFlashStarting(_)).Times(0);
auto error = RunTask(task.get(), modem, FlashTask::Options{});
EXPECT_NE(error, nullptr);
}
TEST_F(FlashTaskTest, NothingToFlash) {
auto task = CreateFlashTask();
auto modem = GetDefaultModem();
EXPECT_CALL(*modem_flasher_, ShouldFlash(modem.get(), _))
.WillOnce(Return(true));
EXPECT_CALL(*modem_flasher_, BuildFlashConfig(modem.get(), _, _))
.WillOnce(Return(GetConfig(kCarrier1, {})));
EXPECT_CALL(*modem_flasher_, RunFlash(modem.get(), _, _, _)).Times(0);
EXPECT_CALL(*delegate_, NotifyFlashStarting(_)).Times(0);
auto error = RunTask(task.get(), modem, FlashTask::Options{});
EXPECT_EQ(error, nullptr);
}
TEST_F(FlashTaskTest, BuildConfigReturnedError) {
auto task = CreateFlashTask();
auto modem = GetDefaultModem();
EXPECT_CALL(*modem_flasher_, ShouldFlash(modem.get(), _))
.WillOnce(Return(true));
EXPECT_CALL(*modem_flasher_, BuildFlashConfig(modem.get(), _, _))
.WillOnce(WithArg<2>([](brillo::ErrorPtr* error) {
*error = Error::Create(FROM_HERE, "foo", "foo");
return nullptr;
}));
EXPECT_CALL(*modem_flasher_, RunFlash(modem.get(), _, _, _)).Times(0);
EXPECT_CALL(*delegate_, NotifyFlashStarting(_)).Times(0);
auto error = RunTask(task.get(), modem.get(), FlashTask::Options{});
EXPECT_NE(error, nullptr);
}
TEST_F(FlashTaskTest, FlashFailure) {
auto task = CreateFlashTask();
auto modem = GetDefaultModem();
base::FilePath new_firmware(kMainFirmware2Path);
EXPECT_CALL(*modem_flasher_, ShouldFlash(modem.get(), _))
.WillOnce(Return(true));
EXPECT_CALL(*modem_flasher_, BuildFlashConfig(modem.get(), _, _))
.WillOnce(Return(GetConfig(
kCarrier1, {{kFwMain, new_firmware, kMainFirmware2Version}})));
EXPECT_CALL(*modem_flasher_, RunFlash(modem.get(), _, _, _))
.WillOnce(WithArg<3>([](brillo::ErrorPtr* error) {
*error = Error::Create(FROM_HERE, "foo", "foo");
return false;
}));
EXPECT_CALL(*delegate_, NotifyFlashStarting(_)).Times(1);
auto error = RunTask(task.get(), modem.get(), FlashTask::Options{});
EXPECT_NE(error, nullptr);
}
TEST_F(FlashTaskTest, FlashSuccess) {
auto task = CreateFlashTask();
auto modem = GetDefaultModem();
base::FilePath new_firmware(kMainFirmware2Path);
EXPECT_CALL(*modem_flasher_, ShouldFlash(modem.get(), _))
.WillOnce(Return(true));
EXPECT_CALL(*modem_flasher_, BuildFlashConfig(modem.get(), _, _))
.WillOnce(Return(GetConfig(
kCarrier1, {{kFwMain, new_firmware, kMainFirmware2Version}})));
EXPECT_CALL(*modem_flasher_, RunFlash(modem.get(), _, _, _))
.WillOnce(Return(true));
EXPECT_CALL(*delegate_, NotifyFlashStarting(_)).Times(1);
EXPECT_CALL(*metrics_, SendFwFlashTime(_)).Times(1);
// The cleanup callback marks the end of flashing the firmware. The delegate
// will typically run it once the modem comes back up.
EXPECT_CALL(*delegate_, RegisterOnModemReappearanceCallback(_, _))
.WillOnce(WithArg<1>([](base::OnceClosure reg_cb) {
base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE, std::move(reg_cb));
}));
auto error = RunTask(task.get(), modem, FlashTask::Options{});
EXPECT_EQ(error, nullptr);
}
TEST_F(FlashTaskTest, WritesToJournal) {
auto task = CreateFlashTask();
auto modem = GetDefaultModem();
base::FilePath new_firmware(kMainFirmware2Path);
EXPECT_CALL(*modem_flasher_, ShouldFlash(modem.get(), _))
.WillOnce(Return(true));
EXPECT_CALL(*modem_flasher_, BuildFlashConfig(modem.get(), _, _))
.WillOnce(Return(GetConfig(
kCarrier1, {{kFwMain, new_firmware, kMainFirmware2Version}})));
EXPECT_CALL(*modem_flasher_, RunFlash(modem.get(), _, _, _))
.WillOnce(Return(true));
EXPECT_CALL(*journal_, MarkStartOfFlashingFirmware(_, kDeviceId1, _))
.WillOnce(Return(kJournalEntryId));
EXPECT_CALL(*journal_, MarkEndOfFlashingFirmware(kJournalEntryId)).Times(1);
// The cleanup callback marks the end of flashing the firmware. The delegate
// will typically run it once the modem comes back up.
EXPECT_CALL(*delegate_, RegisterOnModemReappearanceCallback(_, _))
.WillOnce(WithArg<1>([](base::OnceClosure reg_cb) {
base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE, std::move(reg_cb));
}));
auto error = RunTask(task.get(), modem, FlashTask::Options{});
EXPECT_EQ(error, nullptr);
}
TEST_F(FlashTaskTest, WritesCarrierToJournal) {
auto task = CreateFlashTask();
auto modem = GetDefaultModem();
base::FilePath new_firmware(kCarrier2Firmware1Path);
EXPECT_CALL(*modem_flasher_, ShouldFlash(modem.get(), _))
.WillOnce(Return(true));
EXPECT_CALL(*modem_flasher_, BuildFlashConfig(modem.get(), _, _))
.WillOnce(Return(GetConfig(
kCarrier2, {{kFwCarrier, new_firmware, kCarrier2Firmware1Version}})));
EXPECT_CALL(*modem_flasher_, RunFlash(modem.get(), _, _, _))
.WillOnce(Return(true));
EXPECT_CALL(*journal_, MarkStartOfFlashingFirmware(_, kDeviceId1, kCarrier2))
.WillOnce(Return(kJournalEntryId));
EXPECT_CALL(*delegate_, RegisterOnModemReappearanceCallback(_, _))
.WillOnce(WithArg<1>([](base::OnceClosure reg_cb) {
base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE, std::move(reg_cb));
}));
RunTask(task.get(), modem, FlashTask::Options{});
}
TEST_F(FlashTaskTest, WritesToJournalOnFailure) {
auto task = CreateFlashTask();
auto modem = GetDefaultModem();
base::FilePath new_firmware(kMainFirmware2Path);
EXPECT_CALL(*modem_flasher_, ShouldFlash(modem.get(), _))
.WillOnce(Return(true));
EXPECT_CALL(*modem_flasher_, BuildFlashConfig(modem.get(), _, _))
.WillOnce(Return(GetConfig(
kCarrier1, {{kFwMain, new_firmware, kMainFirmware2Version}})));
EXPECT_CALL(*modem_flasher_, RunFlash(modem.get(), _, _, _))
.WillOnce(Return(false));
EXPECT_CALL(*journal_, MarkStartOfFlashingFirmware(_, kDeviceId1, _))
.WillOnce(Return(kJournalEntryId));
EXPECT_CALL(*journal_, MarkEndOfFlashingFirmware(kJournalEntryId)).Times(1);
// We should complete inline on failure. No callback should be registered.
EXPECT_CALL(*delegate_, RegisterOnModemReappearanceCallback(_, _)).Times(0);
RunTask(task.get(), modem, FlashTask::Options{});
}
TEST_F(FlashTaskTest, InhibitDuringFlash) {
auto task = CreateFlashTask();
auto modem = GetDefaultModem();
base::FilePath new_firmware(kMainFirmware2Path);
EXPECT_CALL(*modem_flasher_, ShouldFlash(modem.get(), _))
.WillOnce(Return(true));
EXPECT_CALL(*modem_flasher_, BuildFlashConfig(modem.get(), _, _))
.WillOnce(Return(GetConfig(
kCarrier1, {{kFwMain, new_firmware, kMainFirmware2Version}})));
EXPECT_CALL(*modem_flasher_, RunFlash(modem.get(), _, _, _))
.WillOnce(Return(true));
EXPECT_CALL(*delegate_, RegisterOnModemReappearanceCallback(_, _))
.WillOnce(WithArg<1>([](base::OnceClosure reg_cb) {
base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE, std::move(reg_cb));
}));
EXPECT_CALL(*modem, SetInhibited(true)).Times(1);
EXPECT_CALL(*modem, SetInhibited(false)).Times(1);
RunTask(task.get(), modem, FlashTask::Options{});
}
TEST_F(FlashTaskTest, IgnoreBlock) {
auto task = CreateFlashTask();
auto modem = GetDefaultModem();
base::FilePath new_firmware(kMainFirmware2Path);
EXPECT_CALL(*modem_flasher_, ShouldFlash(modem.get(), _))
.WillRepeatedly(Return(false));
EXPECT_CALL(*modem_flasher_, BuildFlashConfig(modem.get(), _, _))
.WillOnce(Return(GetConfig(
kCarrier1, {{kFwMain, new_firmware, kMainFirmware2Version}})));
EXPECT_CALL(*modem_flasher_, RunFlash(modem.get(), _, _, _))
.WillOnce(Return(true));
EXPECT_CALL(*delegate_, RegisterOnModemReappearanceCallback(_, _))
.WillOnce(WithArg<1>([](base::OnceClosure reg_cb) {
base::SequencedTaskRunner::GetCurrentDefault()->PostTask(
FROM_HERE, std::move(reg_cb));
}));
auto error = RunTask(task.get(), modem,
FlashTask::Options{.should_always_flash = true});
EXPECT_EQ(error, nullptr);
}
TEST_F(FlashTaskTest, SyncCleanupForStubModem) {
auto task = CreateFlashTask();
auto modem = GetDefaultModem();
base::FilePath new_firmware(kMainFirmware2Path);
// Pretend this is a stub modem.
EXPECT_CALL(*modem, IsPresent()).WillRepeatedly(Return(false));
EXPECT_CALL(*modem_flasher_, ShouldFlash(modem.get(), _))
.WillOnce(Return(true));
EXPECT_CALL(*modem_flasher_, BuildFlashConfig(modem.get(), _, _))
.WillOnce(Return(GetConfig(
kCarrier1, {{kFwMain, new_firmware, kMainFirmware2Version}})));
EXPECT_CALL(*modem_flasher_, RunFlash(modem.get(), _, _, _))
.WillOnce(Return(true));
EXPECT_CALL(*metrics_, SendFwFlashTime(_)).Times(1);
EXPECT_CALL(*journal_, MarkStartOfFlashingFirmware(_, kDeviceId1, _))
.WillOnce(Return(kJournalEntryId));
// We should expect this to run synchronously.
EXPECT_CALL(*journal_, MarkEndOfFlashingFirmware(kJournalEntryId)).Times(1);
RunTask(task.get(), modem, FlashTask::Options{});
}
} // namespace modemfwd