api/validate: add each_in validator
Add new validator to allow validating repeated fields, and their
subfieds, have specific values.
BUG=None
TEST=./run_pytest
Change-Id: I2c1645c64c2e69253d9a5b11c3dd50809bbae3b9
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/chromite/+/2595815
Tested-by: Alex Klein <saklein@chromium.org>
Commit-Queue: Alex Klein <saklein@chromium.org>
Reviewed-by: Michael Mortensen <mmortensen@google.com>
diff --git a/api/validate.py b/api/validate.py
index 0f61a39..7bb7b23 100644
--- a/api/validate.py
+++ b/api/validate.py
@@ -36,6 +36,9 @@
Returns:
str|None|int|list|Message|bool - The value of the field.
"""
+ if not field:
+ return message
+
value = message
for part in field.split('.'):
if not isinstance(value, protobuf_message.Message):
@@ -79,7 +82,7 @@
def is_in(field, values):
- """Validate |field| contains |value|.
+ """Validate |field| is an element of |values|.
Args:
field (str): The field being checked. May be . separated nested fields.
@@ -105,6 +108,41 @@
return decorator
+def each_in(field, subfield, values, optional=False):
+ """Validate each |subfield| of the repeated |field| is in |values|.
+
+ Args:
+ field (str): The field being checked. May be . separated nested fields.
+ subfield (str|None): The field in the repeated |field| to validate, or None
+ when |field| is not a repeated message, e.g. enum, scalars.
+ values (list): The possible values field may take.
+ optional (bool): Also allow the field to be empty when True.
+ """
+ assert field
+ assert values
+
+ def decorator(func):
+ @functools.wraps(func)
+ def _is_in(input_proto, output_proto, config, *args, **kwargs):
+ if config.do_validation:
+ members = _value(field, input_proto) or []
+ if not optional and not members:
+ cros_build_lib.Die('The %s field is empty.', field)
+ for member in members:
+ logging.debug('Validating %s.[each].%s is in %r.', field, subfield,
+ values)
+ value = _value(subfield, member)
+ if value not in values:
+ cros_build_lib.Die('%s.[each].%s (%r) must be in %r is required.',
+ field, subfield, value, values)
+
+ return func(input_proto, output_proto, config, *args, **kwargs)
+
+ return _is_in
+
+ return decorator
+
+
# pylint: disable=docstring-misnamed-args
def require(*fields):
"""Verify |fields| have all been set.
diff --git a/api/validate_unittest.py b/api/validate_unittest.py
index 00e57ab..fdc6561 100644
--- a/api/validate_unittest.py
+++ b/api/validate_unittest.py
@@ -18,10 +18,16 @@
from chromite.lib import cros_test_lib
from chromite.lib import osutils
-
assert sys.version_info >= (3, 6), 'This module requires Python 3.6+'
+# These tests test the validators by defining a local `impl` function that
+# has the same parameters as a controller function and the validator being
+# tested. The validators don't care that they aren't actually controller
+# functions, they just need the function to look like one, so it works
+# to pass an arbitrary message; i.e. passing one of the Request messages
+# we'd usually expect in a controller is not required. The validator
+# just needs to be checking one of the fields on the message being used.
class ExistsTest(cros_test_lib.TempDirTestCase, api_config.ApiConfigMixin):
"""Tests for the exists validator."""
@@ -100,6 +106,186 @@
impl(common_pb2.Chroot(), None, self.no_validate_config)
+class EachInTest(cros_test_lib.TestCase, api_config.ApiConfigMixin):
+ """Tests for the each_in validator."""
+
+ # Easier access to the enum values.
+ ENUM_FOO = build_api_test_pb2.TEST_ENUM_FOO
+ ENUM_BAR = build_api_test_pb2.TEST_ENUM_BAR
+ ENUM_BAZ = build_api_test_pb2.TEST_ENUM_BAZ
+
+ # pylint: disable=docstring-misnamed-args
+ def _message_request(self, *messages):
+ """Build a request instance, filling out the messages field.
+
+ Args:
+ messages: Each messages data (id, name, flag, enum) as lists. Only
+ requires as many as are set. e.g. _request([1], [2]) will create two
+ messages with only ids set. _request([1, 'name']) will create one with
+ id and name set, but not flag or enum.
+ """
+ request = build_api_test_pb2.TestRequestMessage()
+ for message in messages or []:
+ msg = request.messages.add()
+ try:
+ msg.id = message[0]
+ msg.name = message[1]
+ msg.flag = message[2]
+ except IndexError:
+ pass
+
+ return request
+
+ def _enums_request(self, *enum_values):
+ """Build a request instance, setting the test_enums field."""
+ request = build_api_test_pb2.TestRequestMessage()
+ for value in enum_values:
+ request.test_enums.append(value)
+
+ return request
+
+ def _numbers_request(self, *numbers):
+ """Build a request instance, setting the numbers field."""
+ request = build_api_test_pb2.TestRequestMessage()
+ request.numbers.extend(numbers)
+
+ return request
+
+ def test_message_in(self):
+ """Test valid values."""
+
+ @validate.each_in('messages', 'name', ['foo', 'bar'])
+ def impl(_input_proto, _output_proto, _config):
+ pass
+
+ impl(self._message_request([1, 'foo']), None, self.api_config)
+ impl(self._message_request([1, 'foo'], [2, 'bar']), None, self.api_config)
+
+ def test_enum_in(self):
+ """Test valid enum values."""
+
+ @validate.each_in('test_enums', None, [self.ENUM_FOO, self.ENUM_BAR])
+ def impl(_input_proto, _output_proto, _config):
+ pass
+
+ impl(self._enums_request(self.ENUM_FOO), None, self.api_config)
+ impl(self._enums_request(self.ENUM_FOO, self.ENUM_BAR), None,
+ self.api_config)
+
+ def test_scalar_in(self):
+ """Test valid scalar values."""
+
+ @validate.each_in('numbers', None, [1, 2])
+ def impl(_input_proto, _output_proto, _config):
+ pass
+
+ impl(self._numbers_request(1), None, self.api_config)
+ impl(self._numbers_request(1, 2), None, self.api_config)
+
+ def test_message_not_in(self):
+ """Test an invalid value."""
+
+ @validate.each_in('messages', 'name', ['foo', 'bar'])
+ def impl(_input_proto, _output_proto, _config):
+ pass
+
+ # Should be failing on the invalid value.
+ with self.assertRaises(cros_build_lib.DieSystemExit):
+ impl(self._message_request([1, 'invalid']), None, self.api_config)
+ with self.assertRaises(cros_build_lib.DieSystemExit):
+ impl(self._message_request([1, 'invalid'], [2, 'invalid']), None,
+ self.api_config)
+ with self.assertRaises(cros_build_lib.DieSystemExit):
+ impl(self._message_request([1, 'foo'], [2, 'invalid']), None,
+ self.api_config)
+
+ def test_enum_not_in(self):
+ """Test an invalid enum value."""
+
+ @validate.each_in('test_enums', None, [self.ENUM_FOO, self.ENUM_BAR])
+ def impl(_input_proto, _output_proto, _config):
+ pass
+
+ # Only invalid values.
+ with self.assertRaises(cros_build_lib.DieSystemExit):
+ impl(self._enums_request(self.ENUM_BAZ), None, self.api_config)
+ # Mixed valid/invalid values.
+ with self.assertRaises(cros_build_lib.DieSystemExit):
+ impl(self._enums_request(self.ENUM_FOO, self.ENUM_BAZ), None,
+ self.api_config)
+
+ def test_scalar_not_in(self):
+ """Test invalid scalar value."""
+
+ @validate.each_in('numbers', None, [1, 2])
+ def impl(_input_proto, _output_proto, _config):
+ pass
+
+ # Only invalid values.
+ with self.assertRaises(cros_build_lib.DieSystemExit):
+ impl(self._numbers_request(3), None, self.api_config)
+ # Mixed valid/invalid values.
+ with self.assertRaises(cros_build_lib.DieSystemExit):
+ impl(self._numbers_request(1, 2, 3), None, self.api_config)
+
+ def test_not_set(self):
+ """Test an unset value."""
+
+ @validate.each_in('messages', 'name', ['foo', 'bar'])
+ def impl(_input_proto, _output_proto, _config):
+ pass
+
+ # Should be failing without a value set.
+ # No entries in the field.
+ with self.assertRaises(cros_build_lib.DieSystemExit):
+ impl(self._message_request(), None, self.api_config)
+ # No value set on lone entry.
+ with self.assertRaises(cros_build_lib.DieSystemExit):
+ impl(self._message_request([1]), None, self.api_config)
+ # No value set on multiple entries.
+ with self.assertRaises(cros_build_lib.DieSystemExit):
+ impl(self._message_request([1], [2]), None, self.api_config)
+ # Some valid and some invalid entries.
+ with self.assertRaises(cros_build_lib.DieSystemExit):
+ impl(self._message_request([1, 'foo'], [2]), None, self.api_config)
+
+ def test_optional(self):
+ """Test optional argument."""
+
+ @validate.each_in('messages', 'name', ['foo', 'bar'], optional=True)
+ @validate.each_in('test_enums', None, [self.ENUM_FOO, self.ENUM_BAR],
+ optional=True)
+ @validate.each_in('numbers', None, [1, 2], optional=True)
+ def impl(_input_proto, _output_proto, _config):
+ pass
+
+ # No entries in the field succeeds.
+ impl(self._message_request(), None, self.api_config)
+
+ # Still fails when entries exist but value unset cases.
+ # No value set on lone entry.
+ with self.assertRaises(cros_build_lib.DieSystemExit):
+ impl(self._message_request([1]), None, self.api_config)
+ # No value set on multiple entries.
+ with self.assertRaises(cros_build_lib.DieSystemExit):
+ impl(self._message_request([1], [2]), None, self.api_config)
+ # Some valid and some invalid entries.
+ with self.assertRaises(cros_build_lib.DieSystemExit):
+ impl(self._message_request([1, 'foo'], [2]), None, self.api_config)
+
+ def test_skip_validation(self):
+ """Test skipping validation case."""
+
+ @validate.each_in('messages', 'name', ['foo', 'bar'])
+ @validate.each_in('test_enums', None, [self.ENUM_FOO, self.ENUM_BAR])
+ @validate.each_in('numbers', None, [1, 2])
+ def impl(_input_proto, _output_proto, _config):
+ pass
+
+ # This would otherwise raise an error for multiple invalid fields.
+ impl(self._message_request([1, 'invalid']), None, self.no_validate_config)
+
+
class RequireTest(cros_test_lib.TestCase, api_config.ApiConfigMixin):
"""Tests for the require validator."""