stage_build_to_usbkey during setup_usbkey if needed.

Change-Id: Ibc5c6eb3e5bf31ccae16bc4e818290ba5d524ba4
BUG=b:159042133
TEST=utils/unittest_suite.py autotest_lib.server.cros.faft.firmware_test_unittest
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/third_party/autotest/+/2264394
Reviewed-by: Brent Peterson <brentpeterson@chromium.org>
Reviewed-by: Garry Wang <xianuowang@chromium.org>
Reviewed-by: Patrick Georgi <pgeorgi@chromium.org>
Commit-Queue: Andrew Luo <aluo@chromium.org>
Tested-by: Andrew Luo <aluo@chromium.org>
diff --git a/server/cros/faft/firmware_test.py b/server/cros/faft/firmware_test.py
index 0fc0114..3c97585 100644
--- a/server/cros/faft/firmware_test.py
+++ b/server/cros/faft/firmware_test.py
@@ -634,6 +634,7 @@
                                   used for recovery boot, like Ctrl-U USB boot.
         """
         if usbkey:
+            self.stage_build_to_usbkey()
             self.assert_test_image_in_usb_disk()
         elif host is None:
             # USB disk is not required for the test. Better to mux it to host.
diff --git a/server/cros/faft/firmware_test_unittest.py b/server/cros/faft/firmware_test_unittest.py
index c71ba29..1ae77275 100644
--- a/server/cros/faft/firmware_test_unittest.py
+++ b/server/cros/faft/firmware_test_unittest.py
@@ -194,6 +194,30 @@
         self.assertFalse(self.test.stage_build_to_usbkey())
         self.test._client.stage_build_to_usb.assert_called_with("dummy_build")
 
+    def test_setup_usbkey(self):
+        self.test._client.host_info_store.get.return_value.build = "dummy_build"
+        self.test._client._servo_host.validate_image_usbkey.return_value = (
+            "another_build")
+        self.test.assert_test_image_in_usb_disk = mock.MagicMock()
+        self.test.set_servo_v4_role_to_snk = mock.MagicMock()
+        self.test.setup_usbkey(usbkey=True)
+        self.test._client.stage_build_to_usb.assert_called_with("dummy_build")
+        self.test.assert_test_image_in_usb_disk.assert_called()
+        self.test.set_servo_v4_role_to_snk.assert_called()
+
+    def test_setup_usbkey_no_stage(self):
+        self.test._client.host_info_store.get.return_value.build = "dummy_build"
+        self.test._client._servo_host.validate_image_usbkey.return_value = (
+            "another_build")
+        self.test.assert_test_image_in_usb_disk = mock.MagicMock()
+        self.test.set_servo_v4_role_to_snk = mock.MagicMock()
+        self.test.servo = mock.MagicMock()
+        self.test.setup_usbkey(usbkey=False)
+        self.test._client.stage_build_to_usb.assert_not_called()
+        self.test.assert_test_image_in_usb_disk.assert_not_called()
+        self.test.servo.switch_usbkey.assert_called_with('host')
+        self.test.set_servo_v4_role_to_snk.assert_not_called()
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/server/site_tests/firmware_FAFTModeTransitions/firmware_FAFTModeTransitions.py b/server/site_tests/firmware_FAFTModeTransitions/firmware_FAFTModeTransitions.py
index f9059d7..eef202f 100644
--- a/server/site_tests/firmware_FAFTModeTransitions/firmware_FAFTModeTransitions.py
+++ b/server/site_tests/firmware_FAFTModeTransitions/firmware_FAFTModeTransitions.py
@@ -40,7 +40,6 @@
         if 'rec' in mode_seq:
             logging.info("Mode sequence contains 'rec', setup USB stick with"
                          " image.")
-            self.stage_build_to_usbkey()
             self.setup_usbkey(usbkey=True)
 
         m1 = mode_seq[0]
diff --git a/server/site_tests/firmware_InvalidUSB/firmware_InvalidUSB.py b/server/site_tests/firmware_InvalidUSB/firmware_InvalidUSB.py
index 06f9d27..5244d1f 100644
--- a/server/site_tests/firmware_InvalidUSB/firmware_InvalidUSB.py
+++ b/server/site_tests/firmware_InvalidUSB/firmware_InvalidUSB.py
@@ -30,6 +30,7 @@
     def initialize(self, host, cmdline_args):
         """Initialize the test"""
         super(firmware_InvalidUSB, self).initialize(host, cmdline_args)
+        self.setup_usbkey(usbkey=True)
         self.servo.switch_usbkey('host')
         usb_dev = self.servo.probe_host_usb_dev()
         self.assert_test_image_in_usb_disk(usb_dev)