api/validate: Add require_any validator.
BUG=None
TEST=run_pytest
Change-Id: I4f9a6be8f6d99bbaa12c8ec4a1fa170823585511
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/chromite/+/2595826
Tested-by: Alex Klein <saklein@chromium.org>
Reviewed-by: Michael Mortensen <mmortensen@google.com>
Commit-Queue: Alex Klein <saklein@chromium.org>
diff --git a/api/validate.py b/api/validate.py
index d8a075d..0f61a39 100644
--- a/api/validate.py
+++ b/api/validate.py
@@ -132,6 +132,35 @@
return decorator
+# pylint: disable=docstring-misnamed-args
+def require_any(*fields):
+ """Verify at least one of |fields| have been set.
+
+ Args:
+ fields (str): The fields being checked. May be . separated nested fields.
+ """
+ assert fields
+
+ def decorator(func):
+ @functools.wraps(func)
+ def _require(input_proto, output_proto, config, *args, **kwargs):
+ if config.do_validation:
+ for field in fields:
+ logging.debug('Validating %s is set.', field)
+ value = _value(field, input_proto)
+ if value:
+ break
+ else:
+ cros_build_lib.Die('At least one of the following must be set: %s',
+ ', '.join(fields))
+
+ return func(input_proto, output_proto, config, *args, **kwargs)
+
+ return _require
+
+ return decorator
+
+
def require_each(field, subfields, allow_empty=True):
"""Verify |field| each have all of the |subfields| set.
diff --git a/api/validate_unittest.py b/api/validate_unittest.py
index e42b740..00e57ab 100644
--- a/api/validate_unittest.py
+++ b/api/validate_unittest.py
@@ -100,8 +100,8 @@
impl(common_pb2.Chroot(), None, self.no_validate_config)
-class RequiredTest(cros_test_lib.TestCase, api_config.ApiConfigMixin):
- """Tests for the required validator."""
+class RequireTest(cros_test_lib.TestCase, api_config.ApiConfigMixin):
+ """Tests for the require validator."""
def test_invalid_field(self):
"""Test validator fails when given an unset value."""
@@ -135,7 +135,7 @@
impl(in_proto, None, self.api_config)
def test_mixed(self):
- """Test validator passes when given a set value."""
+ """Test validator fails when given a set value and an unset value."""
@validate.require('path', 'env.use_flags')
def impl(_input_proto, _output_proto, _config):
@@ -154,6 +154,54 @@
impl(common_pb2.Chroot(), None, self.no_validate_config)
+class RequireAnyTest(cros_test_lib.TestCase, api_config.ApiConfigMixin):
+ """Tests for the require_any validator."""
+
+ def _get_request(self, mid: int = None, name: str = None, flag: bool = None):
+ """Build a request instance from the given data."""
+ request = build_api_test_pb2.MultiFieldMessage()
+
+ if mid:
+ request.id = mid
+ if name:
+ request.name = name
+ if flag:
+ request.flag = flag
+
+ return request
+
+ def test_invalid_field(self):
+ """Test validator fails when given an invalid field."""
+
+ @validate.require_any('does.not.exist', 'also.invalid')
+ def impl(_input_proto, _output_proto, _config):
+ self.fail('Incorrectly allowed method to execute.')
+
+ with self.assertRaises(cros_build_lib.DieSystemExit):
+ impl(self._get_request(), None, self.api_config)
+
+ def test_not_set(self):
+ """Test validator fails when given unset values."""
+
+ @validate.require_any('id', 'name')
+ def impl(_input_proto, _output_proto, _config):
+ self.fail('Incorrectly allowed method to execute.')
+
+ with self.assertRaises(cros_build_lib.DieSystemExit):
+ impl(self._get_request(flag=True), None, self.api_config)
+
+ def test_set(self):
+ """Test validator passes when given set values."""
+
+ @validate.require_any('id', 'name')
+ def impl(_input_proto, _output_proto, _config):
+ pass
+
+ impl(self._get_request(1), None, self.api_config)
+ impl(self._get_request(name='foo'), None, self.api_config)
+ impl(self._get_request(1, name='foo'), None, self.api_config)
+
+
class RequireEachTest(cros_test_lib.TestCase, api_config.ApiConfigMixin):
"""Tests for the require_each validator."""