api/validate: add each_in validator

Add new validator to allow validating repeated fields, and their
subfieds, have specific values.

BUG=None
TEST=./run_pytest

Change-Id: I2c1645c64c2e69253d9a5b11c3dd50809bbae3b9
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/chromite/+/2595815
Tested-by: Alex Klein <saklein@chromium.org>
Commit-Queue: Alex Klein <saklein@chromium.org>
Reviewed-by: Michael Mortensen <mmortensen@google.com>
diff --git a/api/validate.py b/api/validate.py
index 0f61a39..7bb7b23 100644
--- a/api/validate.py
+++ b/api/validate.py
@@ -36,6 +36,9 @@
   Returns:
     str|None|int|list|Message|bool - The value of the field.
   """
+  if not field:
+    return message
+
   value = message
   for part in field.split('.'):
     if not isinstance(value, protobuf_message.Message):
@@ -79,7 +82,7 @@
 
 
 def is_in(field, values):
-  """Validate |field| contains |value|.
+  """Validate |field| is an element of |values|.
 
   Args:
     field (str): The field being checked. May be . separated nested fields.
@@ -105,6 +108,41 @@
   return decorator
 
 
+def each_in(field, subfield, values, optional=False):
+  """Validate each |subfield| of the repeated |field| is in |values|.
+
+  Args:
+    field (str): The field being checked. May be . separated nested fields.
+    subfield (str|None): The field in the repeated |field| to validate, or None
+      when |field| is not a repeated message, e.g. enum, scalars.
+    values (list): The possible values field may take.
+    optional (bool): Also allow the field to be empty when True.
+  """
+  assert field
+  assert values
+
+  def decorator(func):
+    @functools.wraps(func)
+    def _is_in(input_proto, output_proto, config, *args, **kwargs):
+      if config.do_validation:
+        members = _value(field, input_proto) or []
+        if not optional and not members:
+          cros_build_lib.Die('The %s field is empty.', field)
+        for member in members:
+          logging.debug('Validating %s.[each].%s is in %r.', field, subfield,
+                        values)
+          value = _value(subfield, member)
+          if value not in values:
+            cros_build_lib.Die('%s.[each].%s (%r) must be in %r is required.',
+                               field, subfield, value, values)
+
+      return func(input_proto, output_proto, config, *args, **kwargs)
+
+    return _is_in
+
+  return decorator
+
+
 # pylint: disable=docstring-misnamed-args
 def require(*fields):
   """Verify |fields| have all been set.
diff --git a/api/validate_unittest.py b/api/validate_unittest.py
index 00e57ab..fdc6561 100644
--- a/api/validate_unittest.py
+++ b/api/validate_unittest.py
@@ -18,10 +18,16 @@
 from chromite.lib import cros_test_lib
 from chromite.lib import osutils
 
-
 assert sys.version_info >= (3, 6), 'This module requires Python 3.6+'
 
 
+# These tests test the validators by defining a local `impl` function that
+# has the same parameters as a controller function and the validator being
+# tested. The validators don't care that they aren't actually controller
+# functions, they just need the function to look like one, so it works
+# to pass an arbitrary message; i.e. passing one of the Request messages
+# we'd usually expect in a controller is not required. The validator
+# just needs to be checking one of the fields on the message being used.
 class ExistsTest(cros_test_lib.TempDirTestCase, api_config.ApiConfigMixin):
   """Tests for the exists validator."""
 
@@ -100,6 +106,186 @@
     impl(common_pb2.Chroot(), None, self.no_validate_config)
 
 
+class EachInTest(cros_test_lib.TestCase, api_config.ApiConfigMixin):
+  """Tests for the each_in validator."""
+
+  # Easier access to the enum values.
+  ENUM_FOO = build_api_test_pb2.TEST_ENUM_FOO
+  ENUM_BAR = build_api_test_pb2.TEST_ENUM_BAR
+  ENUM_BAZ = build_api_test_pb2.TEST_ENUM_BAZ
+
+  # pylint: disable=docstring-misnamed-args
+  def _message_request(self, *messages):
+    """Build a request instance, filling out the messages field.
+
+    Args:
+      messages: Each messages data (id, name, flag, enum) as lists. Only
+        requires as many as are set. e.g. _request([1], [2]) will create two
+        messages with only ids set. _request([1, 'name']) will create one with
+        id and name set, but not flag or enum.
+    """
+    request = build_api_test_pb2.TestRequestMessage()
+    for message in messages or []:
+      msg = request.messages.add()
+      try:
+        msg.id = message[0]
+        msg.name = message[1]
+        msg.flag = message[2]
+      except IndexError:
+        pass
+
+    return request
+
+  def _enums_request(self, *enum_values):
+    """Build a request instance, setting the test_enums field."""
+    request = build_api_test_pb2.TestRequestMessage()
+    for value in enum_values:
+      request.test_enums.append(value)
+
+    return request
+
+  def _numbers_request(self, *numbers):
+    """Build a request instance, setting the numbers field."""
+    request = build_api_test_pb2.TestRequestMessage()
+    request.numbers.extend(numbers)
+
+    return request
+
+  def test_message_in(self):
+    """Test valid values."""
+
+    @validate.each_in('messages', 'name', ['foo', 'bar'])
+    def impl(_input_proto, _output_proto, _config):
+      pass
+
+    impl(self._message_request([1, 'foo']), None, self.api_config)
+    impl(self._message_request([1, 'foo'], [2, 'bar']), None, self.api_config)
+
+  def test_enum_in(self):
+    """Test valid enum values."""
+
+    @validate.each_in('test_enums', None, [self.ENUM_FOO, self.ENUM_BAR])
+    def impl(_input_proto, _output_proto, _config):
+      pass
+
+    impl(self._enums_request(self.ENUM_FOO), None, self.api_config)
+    impl(self._enums_request(self.ENUM_FOO, self.ENUM_BAR), None,
+         self.api_config)
+
+  def test_scalar_in(self):
+    """Test valid scalar values."""
+
+    @validate.each_in('numbers', None, [1, 2])
+    def impl(_input_proto, _output_proto, _config):
+      pass
+
+    impl(self._numbers_request(1), None, self.api_config)
+    impl(self._numbers_request(1, 2), None, self.api_config)
+
+  def test_message_not_in(self):
+    """Test an invalid value."""
+
+    @validate.each_in('messages', 'name', ['foo', 'bar'])
+    def impl(_input_proto, _output_proto, _config):
+      pass
+
+    # Should be failing on the invalid value.
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._message_request([1, 'invalid']), None, self.api_config)
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._message_request([1, 'invalid'], [2, 'invalid']), None,
+           self.api_config)
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._message_request([1, 'foo'], [2, 'invalid']), None,
+           self.api_config)
+
+  def test_enum_not_in(self):
+    """Test an invalid enum value."""
+
+    @validate.each_in('test_enums', None, [self.ENUM_FOO, self.ENUM_BAR])
+    def impl(_input_proto, _output_proto, _config):
+      pass
+
+    # Only invalid values.
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._enums_request(self.ENUM_BAZ), None, self.api_config)
+    # Mixed valid/invalid values.
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._enums_request(self.ENUM_FOO, self.ENUM_BAZ), None,
+           self.api_config)
+
+  def test_scalar_not_in(self):
+    """Test invalid scalar value."""
+
+    @validate.each_in('numbers', None, [1, 2])
+    def impl(_input_proto, _output_proto, _config):
+      pass
+
+    # Only invalid values.
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._numbers_request(3), None, self.api_config)
+    # Mixed valid/invalid values.
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._numbers_request(1, 2, 3), None, self.api_config)
+
+  def test_not_set(self):
+    """Test an unset value."""
+
+    @validate.each_in('messages', 'name', ['foo', 'bar'])
+    def impl(_input_proto, _output_proto, _config):
+      pass
+
+    # Should be failing without a value set.
+    # No entries in the field.
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._message_request(), None, self.api_config)
+    # No value set on lone entry.
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._message_request([1]), None, self.api_config)
+    # No value set on multiple entries.
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._message_request([1], [2]), None, self.api_config)
+    # Some valid and some invalid entries.
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._message_request([1, 'foo'], [2]), None, self.api_config)
+
+  def test_optional(self):
+    """Test optional argument."""
+
+    @validate.each_in('messages', 'name', ['foo', 'bar'], optional=True)
+    @validate.each_in('test_enums', None, [self.ENUM_FOO, self.ENUM_BAR],
+                      optional=True)
+    @validate.each_in('numbers', None, [1, 2], optional=True)
+    def impl(_input_proto, _output_proto, _config):
+      pass
+
+    # No entries in the field succeeds.
+    impl(self._message_request(), None, self.api_config)
+
+    # Still fails when entries exist but value unset cases.
+    # No value set on lone entry.
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._message_request([1]), None, self.api_config)
+    # No value set on multiple entries.
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._message_request([1], [2]), None, self.api_config)
+    # Some valid and some invalid entries.
+    with self.assertRaises(cros_build_lib.DieSystemExit):
+      impl(self._message_request([1, 'foo'], [2]), None, self.api_config)
+
+  def test_skip_validation(self):
+    """Test skipping validation case."""
+
+    @validate.each_in('messages', 'name', ['foo', 'bar'])
+    @validate.each_in('test_enums', None, [self.ENUM_FOO, self.ENUM_BAR])
+    @validate.each_in('numbers', None, [1, 2])
+    def impl(_input_proto, _output_proto, _config):
+      pass
+
+    # This would otherwise raise an error for multiple invalid fields.
+    impl(self._message_request([1, 'invalid']), None, self.no_validate_config)
+
+
 class RequireTest(cros_test_lib.TestCase, api_config.ApiConfigMixin):
   """Tests for the require validator."""