| __author__ = "raphtee@google.com (Travis Miller)" |
| |
| |
| import re, collections, StringIO, sys, unittest |
| |
| |
class StubNotFoundError(Exception):
    """Raised when god is asked to unstub an attribute that was not stubbed"""
| |
| |
class CheckPlaybackError(Exception):
    """Raised when mock playback does not match recorded calls."""
| |
| |
class SaveDataAfterCloseStringIO(StringIO.StringIO):
    """Saves the contents in a final_data property when close() is called.

    Useful as a mock output file object to test both that the file was
    closed and what was written.

    Properties:
      final_data: Set to the StringIO's getvalue() data when close() is
          called.  None if close() has not been called.
    """
    final_data = None

    def close(self):
        """Snapshot the buffer into final_data, then close the buffer.

        Calling close() again is harmless: the snapshot taken by the first
        call is kept.
        """
        # getvalue() raises ValueError on an already-closed buffer, so the
        # snapshot is only taken while the buffer is still open.  Without
        # this guard a second close() would blow up, unlike a real file.
        if not self.closed:
            self.final_data = self.getvalue()
        # StringIO.StringIO is called explicitly rather than through
        # super(), matching how the base close() has always been invoked.
        StringIO.StringIO.close(self)
| |
| |
| |
class argument_comparator(object):
    """Base class for objects that decide whether a call argument is OK."""

    def is_satisfied_by(self, parameter):
        """Return whether parameter is acceptable; subclasses override."""
        raise NotImplementedError()
| |
| |
class equality_comparator(argument_comparator):
    """Accepts an argument that equals a stored value, type included.

    Lists, tuples and dicts are compared element by element, and any
    argument_comparator nested inside the expected value is consulted for
    the corresponding piece of the actual value.
    """

    def __init__(self, value):
        self.value = value


    @staticmethod
    def _types_match(arg1, arg2):
        """Return True when both are strings or share the exact same type."""
        both_strings = (isinstance(arg1, basestring) and
                        isinstance(arg2, basestring))
        return both_strings or type(arg1) == type(arg2)


    @classmethod
    def _compare(cls, actual_arg, expected_arg):
        """Recursively check actual_arg against expected_arg."""
        # a nested comparator gets the final say for its own position
        if isinstance(expected_arg, argument_comparator):
            return expected_arg.is_satisfied_by(actual_arg)

        if not cls._types_match(expected_arg, actual_arg):
            return False

        if isinstance(expected_arg, (list, tuple)):
            # sequences: same length, then pairwise comparison
            if len(actual_arg) != len(expected_arg):
                return False
            for actual_item, expected_item in zip(actual_arg, expected_arg):
                if not cls._compare(actual_item, expected_item):
                    return False
            return True

        if isinstance(expected_arg, dict):
            # dicts: same key set first, then value by value
            actual_keys = sorted(actual_arg.keys())
            expected_keys = sorted(expected_arg.keys())
            if not cls._compare(actual_keys, expected_keys):
                return False
            for key, value in actual_arg.iteritems():
                if not cls._compare(value, expected_arg[key]):
                    return False
            return True

        # everything else falls back to the objects' own inequality test
        if actual_arg != expected_arg:
            return False
        return True


    def is_satisfied_by(self, parameter):
        """Return whether parameter matches the stored value."""
        return self._compare(parameter, self.value)


    def __str__(self):
        """Describe the expectation: str() for comparators, else repr()."""
        expected = self.value
        if isinstance(expected, argument_comparator):
            return str(expected)
        return repr(expected)
| |
| |
class regex_comparator(argument_comparator):
    """Accepts an argument in which a regular expression can be found."""

    def __init__(self, pattern, flags=0):
        """Compile pattern (with optional re flags) for later searching."""
        self.regex = re.compile(pattern, flags)


    def is_satisfied_by(self, parameter):
        """Return True if the pattern is found anywhere in parameter.

        An argument the regex cannot search at all (for example None or a
        number) is simply reported as not matching.
        """
        try:
            return self.regex.search(parameter) is not None
        except TypeError:
            # search() rejects non-string input; treat that as a mismatch so
            # the caller can report an incorrect call instead of crashing
            # inside the re module.
            return False


    def __str__(self):
        """Describe the expectation by its raw pattern text."""
        return self.regex.pattern
| |
| |
class is_string_comparator(argument_comparator):
    """Accepts any argument that is a string (any basestring)."""

    def is_satisfied_by(self, parameter):
        """Return whether parameter is an instance of basestring."""
        is_text = isinstance(parameter, basestring)
        return is_text


    def __str__(self):
        """Describe the expectation."""
        return "a string"
| |
| |
class is_instance_comparator(argument_comparator):
    """Accepts any argument that is an instance of a given class."""

    def __init__(self, cls):
        """Remember cls; it is handed to isinstance() as-is."""
        self.cls = cls


    def is_satisfied_by(self, parameter):
        """Return whether parameter is an instance of the stored class."""
        return isinstance(parameter, self.cls)


    def __str__(self):
        """Describe the expectation."""
        # Wrap in a 1-tuple: isinstance() also accepts a tuple of classes,
        # and a bare tuple on the right of % would be unpacked as several
        # format arguments and raise TypeError.
        return "is a %s" % (self.cls,)
| |
| |
class anything_comparator(argument_comparator):
    """Accepts every argument unconditionally."""

    def is_satisfied_by(self, parameter):
        """Always report a match, whatever parameter is."""
        return True


    def __str__(self):
        """Describe the expectation."""
        return 'anything'
| |
| |
class base_mapping(object):
    """One expected call: a symbol, its expected arguments and its result.

    Every expected positional and keyword argument is wrapped in an
    equality_comparator so that match() can check actual arguments.
    """

    def __init__(self, symbol, return_obj, *args, **dargs):
        self.symbol = symbol
        self.return_obj = return_obj
        self.error = None
        self.args = [equality_comparator(value) for value in args]
        self.dargs = {}
        for name, value in dargs.items():
            self.dargs[name] = equality_comparator(value)


    def match(self, *args, **dargs):
        """Return True if the given call arguments satisfy the expectation."""
        if len(self.args) != len(args) or len(self.dargs) != len(dargs):
            return False

        # positional arguments, in order
        for comparator, actual in zip(self.args, args):
            if not comparator.is_satisfied_by(actual):
                return False

        # keyword arguments that are unexpected or carry the wrong value
        for name, actual in dargs.items():
            if name not in self.dargs:
                return False
            if not self.dargs[name].is_satisfied_by(actual):
                return False

        # expected keyword arguments that were not supplied
        for name in self.dargs:
            if name not in dargs:
                return False

        return True


    def __str__(self):
        """Render the expectation the way a call would be written."""
        return _dump_function_call(self.symbol, self.args, self.dargs)
| |
| |
class function_mapping(base_mapping):
    """Expected function call whose outcome can be set after creation."""

    def __init__(self, symbol, return_val, *args, **dargs):
        parent = super(function_mapping, self)
        parent.__init__(symbol, return_val, *args, **dargs)


    def and_return(self, return_obj):
        """Set the object handed back when this expected call is played."""
        self.return_obj = return_obj


    def and_raises(self, error):
        """Set the exception raised when this expected call is played."""
        self.error = error
| |
| |
class function_any_args_mapping(function_mapping):
    """A mock function mapping that doesn't verify its arguments."""

    def match(self, *args, **dargs):
        """Accept whatever arguments the call was made with."""
        return True
| |
| |
class mock_function(object):
    """Callable stand-in that logs every invocation made on it.

    Each call bumps num_calls and appends the positional and keyword
    arguments to the args and dargs lists.  The result comes from the
    playback callable when one was supplied, otherwise it is
    default_return_val.  Expectations are forwarded to the record callable
    when one was supplied.
    """

    def __init__(self, symbol, default_return_val=None,
                 record=None, playback=None):
        self.symbol = symbol
        self.__name__ = symbol
        self.default_return_val = default_return_val
        self.record = record
        self.playback = playback
        # call history
        self.num_calls = 0
        self.args = []
        self.dargs = []


    def __call__(self, *args, **dargs):
        self.num_calls += 1
        self.args.append(args)
        self.dargs.append(dargs)
        if not self.playback:
            return self.default_return_val
        return self.playback(self.symbol, *args, **dargs)


    def expect_call(self, *args, **dargs):
        """Register an expected call with these arguments; return its mapping."""
        expectation = function_mapping(self.symbol, None, *args, **dargs)
        if self.record:
            self.record(expectation)

        return expectation


    def expect_any_call(self):
        """Like expect_call but don't give a hoot what arguments are passed."""
        expectation = function_any_args_mapping(self.symbol, None)
        if self.record:
            self.record(expectation)

        return expectation
| |
| |
class mask_function(mock_function):
    """A mock_function that keeps a handle on the function it replaces."""

    def __init__(self, symbol, original_function, default_return_val=None,
                 record=None, playback=None):
        parent = super(mask_function, self)
        parent.__init__(symbol, default_return_val, record, playback)
        self.original_function = original_function


    def run_original_function(self, *args, **dargs):
        """Call the replaced function with the given arguments."""
        original = self.original_function
        return original(*args, **dargs)
| |
| |
class mock_class(object):
    """Mirror of a namespace's public attributes.

    Every attribute of cls whose name does not start with an underscore is
    copied onto this object; callables are replaced by mock_function
    objects named "<name>.<symbol>".
    """

    def __init__(self, cls, name, default_ret_val=None,
                 record=None, playback=None):
        self.__name = name
        self.__record = record
        self.__playback = playback

        public_symbols = [symbol for symbol in dir(cls)
                          if not symbol.startswith("_")]
        for symbol in public_symbols:
            attribute = getattr(cls, symbol)
            if not callable(attribute):
                # plain data is shared with the original namespace
                setattr(self, symbol, attribute)
                continue

            qualified_name = "%s.%s" % (self.__name, symbol)
            replacement = mock_function(qualified_name, default_ret_val,
                                        self.__record, self.__playback)
            setattr(self, symbol, replacement)


    def __repr__(self):
        return '<mock_class: %s>' % self.__name
| |
| |
class mock_god(object):
    """Coordinator for a group of mocks sharing one expected-call script.

    Expected calls are queued in self.recording in the order they are
    registered.  Each call made on a mock created by this object is checked
    against the front of that queue.  Mismatches are raised immediately or
    collected in self.errors, depending on fail_fast.  The object also
    tracks attribute stubs so they can be undone, and can capture
    stdout/stderr.
    """

    # Sentinel stored as the "original" of a stub whose attribute must be
    # deleted, rather than restored, when the stub is removed.
    NONEXISTENT_ATTRIBUTE = object()

    def __init__(self, debug=False, fail_fast=True, ut=None):
        """
        With debug=True, all recorded method calls will be printed as
        they happen.
        With fail_fast=True, unexpected calls will immediately cause an
        exception to be raised.  With False, they will be silently recorded and
        only reported when check_playback() is called.
        If ut is given, check_playback() passes the failure text to its
        fail() method.
        """
        self.recording = collections.deque()
        self.errors = []
        self._stubs = []
        self._debug = debug
        self._fail_fast = fail_fast
        self._ut = ut


    def set_fail_fast(self, fail_fast):
        """Change whether playback errors raise as soon as they happen."""
        self._fail_fast = fail_fast


    def create_mock_class_obj(self, cls, name, default_ret_val=None):
        """
        Build a subclass of cls whose instantiation is itself a mocked call.

        Instantiating the returned class is played back under the symbol
        name.  Its expect_new(...) classmethod records that expectation and
        returns the instance the matching instantiation will hand back.
        Instances come from make_new, which replaces every callable
        non-dunder attribute with a mock_function named
        "<name>_<count>.<symbol>" returning default_ret_val by default.
        """
        record = self.__record_call
        playback = self.__method_playback

        class cls_sub(cls):
            cls_count = 0

            # overwrite the initializer so the real one never runs
            def __init__(self, *args, **dargs):
                pass


            @classmethod
            def expect_new(typ, *args, **dargs):
                obj = typ.make_new(*args, **dargs)
                mapping = base_mapping(name, obj, *args, **dargs)
                record(mapping)
                return obj


            def __new__(typ, *args, **dargs):
                # construction is treated like any other mocked call
                return playback(name, *args, **dargs)


            @classmethod
            def make_new(typ, *args, **dargs):
                obj = super(cls_sub, typ).__new__(typ, *args,
                                                  **dargs)

                typ.cls_count += 1
                obj_name = "%s_%s" % (name, typ.cls_count)
                for symbol in dir(obj):
                    if (symbol.startswith("__") and
                            symbol.endswith("__")):
                        continue

                    # properties are left alone: they are looked up on
                    # the class, without being evaluated on the instance
                    if isinstance(getattr(typ, symbol, None), property):
                        continue

                    orig_symbol = getattr(obj, symbol)
                    if callable(orig_symbol):
                        f_name = ("%s.%s" %
                                  (obj_name, symbol))
                        func = mock_function(f_name,
                                             default_ret_val,
                                             record,
                                             playback)
                        setattr(obj, symbol, func)
                    else:
                        setattr(obj, symbol,
                                orig_symbol)

                return obj

        return cls_sub


    def create_mock_class(self, cls, name, default_ret_val=None):
        """
        Given something that defines a namespace cls (class, object,
        module), and a (hopefully unique) name, will create a
        mock_class object with that name and that possesses all
        the public attributes of cls.  default_ret_val sets the
        default_ret_val on all methods of the cls mock.
        """
        return mock_class(cls, name, default_ret_val,
                          self.__record_call, self.__method_playback)


    def create_mock_function(self, symbol, default_return_val=None):
        """
        create a mock_function with name symbol and default return
        value of default_return_val.
        """
        return mock_function(symbol, default_return_val,
                             self.__record_call, self.__method_playback)


    def mock_up(self, obj, name, default_ret_val=None):
        """
        Given an object (class instance or module) and a registration
        name, then replace all its methods with mock function objects
        (passing the original functions to the mock functions).
        """
        for symbol in dir(obj):
            if symbol.startswith("__"):
                continue

            orig_symbol = getattr(obj, symbol)
            if callable(orig_symbol):
                f_name = "%s.%s" % (name, symbol)
                func = mask_function(f_name, orig_symbol,
                                     default_ret_val,
                                     self.__record_call,
                                     self.__method_playback)
                setattr(obj, symbol, func)


    def stub_with(self, namespace, symbol, new_attribute):
        """
        Replace namespace.symbol with new_attribute and remember how to
        undo the replacement (see unstub() and unstub_all()).
        """
        original_attribute = getattr(namespace, symbol,
                                     self.NONEXISTENT_ATTRIBUTE)

        # You only want to save the original attribute in cases where it is
        # directly associated with the object in question. In cases where
        # the attribute is actually inherited via some sort of hierarchy
        # you want to delete the stub (restoring the original structure)
        attribute_is_inherited = (hasattr(namespace, '__dict__') and
                                  symbol not in namespace.__dict__)
        if attribute_is_inherited:
            original_attribute = self.NONEXISTENT_ATTRIBUTE

        newstub = (namespace, symbol, original_attribute, new_attribute)
        self._stubs.append(newstub)
        setattr(namespace, symbol, new_attribute)


    def stub_function(self, namespace, symbol):
        """Stub namespace.symbol with a mock function named symbol."""
        mock_attribute = self.create_mock_function(symbol)
        self.stub_with(namespace, symbol, mock_attribute)


    def stub_class_method(self, cls, symbol):
        """Stub cls.symbol with a mock function wrapped in staticmethod."""
        mock_attribute = self.create_mock_function(symbol)
        self.stub_with(cls, symbol, staticmethod(mock_attribute))


    def stub_class(self, namespace, symbol):
        """Stub the class namespace.symbol with a mock class object."""
        attr = getattr(namespace, symbol)
        mocked = self.create_mock_class_obj(attr, symbol)
        self.stub_with(namespace, symbol, mocked)


    def stub_function_to_return(self, namespace, symbol, object_to_return):
        """Stub out a function with one that always returns a fixed value.

        @param namespace The namespace containing the function to stub out.
        @param symbol The attribute within the namespace to stub out.
        @param object_to_return The value that the stub should return whenever
            it is called.
        """
        self.stub_with(namespace, symbol,
                       lambda *args, **dargs: object_to_return)


    def _perform_unstub(self, stub):
        """Undo one stub record: restore the original or delete the stub."""
        namespace, symbol, orig_attr, _ = stub
        # Identity, not equality: the saved original may define __eq__ in a
        # way that would make it "equal" to the sentinel (or refuse to be
        # compared at all), and it must still be restored.
        if orig_attr is self.NONEXISTENT_ATTRIBUTE:
            delattr(namespace, symbol)
        else:
            setattr(namespace, symbol, orig_attr)


    def unstub(self, namespace, symbol):
        """
        Undo the most recent stub of namespace.symbol.

        Raises StubNotFoundError if no such stub is registered.
        """
        for index in range(len(self._stubs) - 1, -1, -1):
            stub = self._stubs[index]
            if (namespace, symbol) == (stub[0], stub[1]):
                self._perform_unstub(stub)
                # delete by position so exactly the record that was just
                # undone is dropped, without re-scanning by equality
                del self._stubs[index]
                return

        raise StubNotFoundError('no stub registered for attribute %r'
                                % (symbol,))


    def unstub_all(self):
        """Undo every registered stub, most recent first."""
        for stub in reversed(self._stubs):
            self._perform_unstub(stub)
        self._stubs = []


    def __method_playback(self, symbol, *args, **dargs):
        """
        Check an actual mock call against the next expected call.

        Returns the expected call's return object, or raises its error if
        one was set.  On a mismatch the problem is handed to
        _append_error() and None is returned.
        """
        if self._debug:
            print >> sys.__stdout__, (' * Mock call: ' +
                                      _dump_function_call(symbol, args, dargs))

        if len(self.recording) == 0:
            msg = ("unexpected call: %s"
                   % (_dump_function_call(symbol, args, dargs)))
            self._append_error(msg)
            return None

        func_call = self.recording[0]
        if func_call.symbol != symbol:
            msg = ("Unexpected call: %s\nExpected: %s"
                   % (_dump_function_call(symbol, args, dargs),
                      func_call))
            self._append_error(msg)
            return None

        if not func_call.match(*args, **dargs):
            msg = ("Incorrect call: %s\nExpected: %s"
                   % (_dump_function_call(symbol, args, dargs),
                      func_call))
            self._append_error(msg)
            return None

        # this is the expected call so pop it and return
        self.recording.popleft()
        if func_call.error:
            raise func_call.error
        return func_call.return_obj


    def __record_call(self, mapping):
        """Queue mapping as the next expected call."""
        self.recording.append(mapping)


    def _append_error(self, error):
        """Raise error at once under fail_fast, otherwise store it."""
        if self._debug:
            print >> sys.__stdout__, ' *** ' + error
        if self._fail_fast:
            raise CheckPlaybackError(error)
        self.errors.append(error)


    def check_playback(self):
        """
        Report any errors that were encountered during calls
        to __method_playback().

        Raises CheckPlaybackError (after calling ut.fail() when a ut was
        supplied) if playback errors were collected or if expected calls
        are still waiting to be made.
        """
        if len(self.errors) > 0:
            if self._debug:
                # same stream as the errors themselves, so the header is
                # not swallowed when sys.stdout has been replaced
                print >> sys.__stdout__, '\nPlayback errors:'
            for error in self.errors:
                print >> sys.__stdout__, error

            message = '\n'.join(self.errors)
            if self._ut:
                self._ut.fail(message)

            raise CheckPlaybackError(message)
        elif len(self.recording) != 0:
            errors = []
            for func_call in self.recording:
                error = "%s not called" % (func_call,)
                errors.append(error)
                print >> sys.__stdout__, error

            message = '\n'.join(errors)
            if self._ut:
                self._ut.fail(message)

            raise CheckPlaybackError(message)
        self.recording.clear()


    def mock_io(self):
        """Mocks and saves the stdout & stderr output"""
        self.orig_stdout = sys.stdout
        self.orig_stderr = sys.stderr

        self.mock_streams_stdout = StringIO.StringIO('')
        self.mock_streams_stderr = StringIO.StringIO('')

        sys.stdout = self.mock_streams_stdout
        sys.stderr = self.mock_streams_stderr


    def unmock_io(self):
        """Restores the stdout & stderr, and returns both
        output strings"""
        sys.stdout = self.orig_stdout
        sys.stderr = self.orig_stderr
        values = (self.mock_streams_stdout.getvalue(),
                  self.mock_streams_stderr.getvalue())

        self.mock_streams_stdout.close()
        self.mock_streams_stderr.close()
        return values
| |
| |
def _arg_to_str(arg):
    """Format one argument: str() for comparators, repr() otherwise."""
    if not isinstance(arg, argument_comparator):
        return repr(arg)
    return str(arg)
| |
| |
def _dump_function_call(symbol, args, dargs):
    """Render symbol with its arguments as "symbol(a, b, key=value)"."""
    rendered = [_arg_to_str(arg) for arg in args]
    rendered.extend("%s=%s" % (key, _arg_to_str(val))
                    for key, val in dargs.iteritems())
    return "%s(%s)" % (symbol, ', '.join(rendered))