blob: 3952db799933abad3cde11f29d0ae533ce31db30 [file] [log] [blame]
__author__ = "raphtee@google.com (Travis Miller)"
import re, collections, StringIO, sys, unittest
class StubNotFoundError(Exception):
    """Raised by mock_god.unstub() when no matching stub is registered."""
class CheckPlaybackError(Exception):
    """Signals that mock calls made during playback disagree with the
    calls that were recorded as expected."""
class SaveDataAfterCloseStringIO(StringIO.StringIO):
"""Saves the contents in a final_data property when close() is called.
Useful as a mock output file object to test both that the file was
closed and what was written.
Properties:
final_data: Set to the StringIO's getvalue() data when close() is
called. None if close() has not been called.
"""
final_data = None
def close(self):
self.final_data = self.getvalue()
StringIO.StringIO.close(self)
class argument_comparator(object):
    """Base class for predicates that decide whether a call argument is
    acceptable during mock playback.

    Subclasses override is_satisfied_by(); the base version is abstract.
    """

    def is_satisfied_by(self, parameter):
        """Return whether parameter is acceptable; subclasses must override."""
        raise NotImplementedError
class equality_comparator(argument_comparator):
    """Accepts a parameter equal to a stored expected value.

    The comparison is type-strict (any two strings are considered the same
    type), recurses element-wise into lists, tuples and dicts, and defers
    to any argument_comparator found inside the expected value.
    """

    def __init__(self, value):
        self.value = value

    @staticmethod
    def _types_match(arg1, arg2):
        both_strings = (isinstance(arg1, basestring) and
                        isinstance(arg2, basestring))
        return both_strings or type(arg1) == type(arg2)

    @classmethod
    def _compare(cls, actual_arg, expected_arg):
        # A nested comparator decides for itself.
        if isinstance(expected_arg, argument_comparator):
            return expected_arg.is_satisfied_by(actual_arg)
        if not cls._types_match(expected_arg, actual_arg):
            return False

        if isinstance(expected_arg, (list, tuple)):
            # sequences: same length and pairwise match
            if len(actual_arg) != len(expected_arg):
                return False
            return all(cls._compare(got, want)
                       for got, want in zip(actual_arg, expected_arg))

        if isinstance(expected_arg, dict):
            # dicts: identical key sets, then matching values per key
            actual_keys = sorted(actual_arg.keys())
            expected_keys = sorted(expected_arg.keys())
            if not cls._compare(actual_keys, expected_keys):
                return False
            return all(cls._compare(got, expected_arg[key])
                       for key, got in actual_arg.items())

        if actual_arg != expected_arg:
            return False
        return True

    def is_satisfied_by(self, parameter):
        return self._compare(parameter, self.value)

    def __str__(self):
        formatter = str if isinstance(self.value, argument_comparator) else repr
        return formatter(self.value)
class regex_comparator(argument_comparator):
    """Accepts a string parameter in which the compiled pattern is found.

    Matching uses search(), so an unanchored pattern may match anywhere in
    the parameter. Non-string parameters never satisfy the comparator.
    """

    def __init__(self, pattern, flags=0):
        self.regex = re.compile(pattern, flags)

    def is_satisfied_by(self, parameter):
        try:
            return self.regex.search(parameter) is not None
        except TypeError:
            # re rejects non-string input (None, ints, ...). Treat it as a
            # mismatch so the mock reports a wrong call instead of letting a
            # TypeError escape from the code under test.
            return False

    def __str__(self):
        return self.regex.pattern
class is_string_comparator(argument_comparator):
    """Accepts any string parameter (any basestring instance)."""

    def is_satisfied_by(self, parameter):
        matched = isinstance(parameter, basestring)
        return matched

    def __str__(self):
        return "a string"
class is_instance_comparator(argument_comparator):
    """Accepts any parameter that is an instance of cls.

    cls may be a single type or a tuple of types, exactly as accepted by
    the builtin isinstance().
    """

    def __init__(self, cls):
        self.cls = cls

    def is_satisfied_by(self, parameter):
        return isinstance(parameter, self.cls)

    def __str__(self):
        # Wrap in a 1-tuple so that a tuple of types is formatted as one
        # value instead of being taken as the %-format argument list.
        return "is a %s" % (self.cls,)
class anything_comparator(argument_comparator):
    """Wildcard comparator that every parameter satisfies."""

    def is_satisfied_by(self, parameter):
        # The parameter is deliberately ignored.
        return True

    def __str__(self):
        return 'anything'
class base_mapping(object):
    """One expected call: a symbol, the arguments it must receive, and the
    object (or error) to produce when that call is played back.

    Each expected argument is wrapped in an equality_comparator, so plain
    values are compared for equality and argument_comparator instances are
    applied as predicates.
    """

    def __init__(self, symbol, return_obj, *args, **dargs):
        self.symbol = symbol
        self.return_obj = return_obj
        self.error = None
        self.args = [equality_comparator(arg) for arg in args]
        self.dargs = dict((key, equality_comparator(value))
                          for key, value in dargs.items())

    def match(self, *args, **dargs):
        """Return True iff args/dargs satisfy every expected argument."""
        # Arity must agree before individual arguments are compared.
        if len(args) != len(self.args) or len(dargs) != len(self.dargs):
            return False
        for comparator, actual in zip(self.args, args):
            if not comparator.is_satisfied_by(actual):
                return False
        # Unknown keywords fail; known ones must satisfy their comparator.
        for key, actual in dargs.items():
            comparator = self.dargs.get(key)
            if comparator is None or not comparator.is_satisfied_by(actual):
                return False
        # Every expected keyword must have been supplied.
        return all(key in dargs for key in self.dargs)

    def __str__(self):
        return _dump_function_call(self.symbol, self.args, self.dargs)
class function_mapping(base_mapping):
    """Expectation for one call to a mock function.

    The value returned (or the error raised) when the call is played back
    is configured afterwards with and_return() / and_raises().
    """

    def __init__(self, symbol, return_val, *args, **dargs):
        # Forwarded unchanged; the base class stores it as return_obj.
        super(function_mapping, self).__init__(
            symbol, return_val, *args, **dargs)

    def and_return(self, return_obj):
        """Make the matched call return return_obj."""
        self.return_obj = return_obj

    def and_raises(self, error):
        """Make the matched call raise error instead of returning."""
        self.error = error
class function_any_args_mapping(function_mapping):
    """A mock function mapping that doesn't verify its arguments.

    Any combination of positional and keyword arguments is accepted.
    """

    def match(self, *args, **dargs):
        # Every argument combination is intentionally accepted.
        return True
class mock_function(object):
    """A callable stand-in that logs every call it receives.

    Calls are counted and their positional/keyword arguments appended to
    self.args / self.dargs. If a playback hook is set, the call is
    forwarded to it as playback(symbol, *args, **dargs) and its result is
    returned; otherwise default_return_val is returned. Expectations made
    with expect_call()/expect_any_call() are passed to the record hook,
    if one is set.
    """

    def __init__(self, symbol, default_return_val=None,
                 record=None, playback=None):
        self.symbol = symbol
        self.__name__ = symbol
        self.default_return_val = default_return_val
        self.record = record
        self.playback = playback
        # per-call log
        self.num_calls = 0
        self.args = []
        self.dargs = []

    def __call__(self, *args, **dargs):
        self.num_calls += 1
        self.args.append(args)
        self.dargs.append(dargs)
        if not self.playback:
            return self.default_return_val
        return self.playback(self.symbol, *args, **dargs)

    def expect_call(self, *args, **dargs):
        """Record an expected call with exactly these arguments."""
        expectation = function_mapping(self.symbol, None, *args, **dargs)
        if self.record:
            self.record(expectation)
        return expectation

    def expect_any_call(self):
        """Like expect_call but don't give a hoot what arguments are passed."""
        expectation = function_any_args_mapping(self.symbol, None)
        if self.record:
            self.record(expectation)
        return expectation
class mask_function(mock_function):
    """A mock_function that keeps a reference to the function it replaces.

    The original can still be reached via run_original_function().
    """

    def __init__(self, symbol, original_function, default_return_val=None,
                 record=None, playback=None):
        super(mask_function, self).__init__(
            symbol, default_return_val, record, playback)
        self.original_function = original_function

    def run_original_function(self, *args, **dargs):
        """Call the wrapped original function directly, bypassing the mock."""
        original = self.original_function
        return original(*args, **dargs)
class mock_class(object):
    """Object mirroring the public attributes of a namespace cls.

    Every public (non-underscore) callable of cls becomes a mock_function
    named "<name>.<symbol>"; every other public attribute is copied as is.
    """

    def __init__(self, cls, name, default_ret_val=None,
                 record=None, playback=None):
        self.__name = name
        self.__record = record
        self.__playback = playback
        public_symbols = [s for s in dir(cls) if not s.startswith("_")]
        for symbol in public_symbols:
            attr = getattr(cls, symbol)
            if not callable(attr):
                setattr(self, symbol, attr)
                continue
            stand_in = mock_function("%s.%s" % (self.__name, symbol),
                                     default_ret_val,
                                     self.__record, self.__playback)
            setattr(self, symbol, stand_in)

    def __repr__(self):
        return '<mock_class: %s>' % self.__name
class mock_god(object):
    """Records expected mock calls and checks actual calls against them.

    Mocks created through this object share two hooks: expectations
    (expect_call / expect_any_call / expect_new) are appended to
    self.recording, and every actual call is checked, in order, against
    the head of that queue. Mismatches are either raised immediately
    (fail_fast=True) or collected in self.errors for check_playback().

    The object also tracks attribute stubs made with stub_with() and
    friends, so they can be reverted with unstub() / unstub_all().
    """

    # Sentinel meaning "the attribute did not exist before it was stubbed".
    # Always compare it by identity: the stubbed-out value may define an
    # arbitrary __eq__.
    NONEXISTENT_ATTRIBUTE = object()

    def __init__(self, debug=False, fail_fast=True, ut=None):
        """
        With debug=True, all recorded method calls will be printed as
        they happen.
        With fail_fast=True, unexpected calls will immediately cause an
        exception to be raised. With False, they will be silently recorded and
        only reported when check_playback() is called.
        If ut is given, check_playback() calls ut.fail() with the error
        report (presumably ut is a unittest.TestCase).
        """
        self.recording = collections.deque()
        self.errors = []
        self._stubs = []
        self._debug = debug
        self._fail_fast = fail_fast
        self._ut = ut

    @staticmethod
    def _print_to_real_stdout(message):
        """Write message plus a newline to the original process stdout.

        sys.__stdout__ is used so diagnostics stay visible even while
        mock_io() has replaced sys.stdout.
        """
        sys.__stdout__.write(message + '\n')

    def set_fail_fast(self, fail_fast):
        """Switch between raising on the first mismatch and collecting."""
        self._fail_fast = fail_fast

    def create_mock_class_obj(self, cls, name, default_ret_val=None):
        """Return a subclass of cls whose construction is mocked.

        cls_sub.expect_new(*args) builds an instance whose public callables
        are mock_functions and records its construction as expected. Calling
        cls_sub(*args) then plays that expectation back and returns the
        prepared instance. The real cls.__init__ is never run.
        """
        record = self.__record_call
        playback = self.__method_playback

        class cls_sub(cls):
            cls_count = 0

            # overwrite the initializer
            def __init__(self, *args, **dargs):
                pass

            @classmethod
            def expect_new(typ, *args, **dargs):
                obj = typ.make_new(*args, **dargs)
                mapping = base_mapping(name, obj, *args, **dargs)
                record(mapping)
                return obj

            def __new__(typ, *args, **dargs):
                return playback(name, *args, **dargs)

            @classmethod
            def make_new(typ, *args, **dargs):
                obj = super(cls_sub, typ).__new__(typ, *args, **dargs)
                typ.cls_count += 1
                obj_name = "%s_%s" % (name, typ.cls_count)
                for symbol in dir(obj):
                    if symbol.startswith("__") and symbol.endswith("__"):
                        continue
                    # properties are skipped rather than replaced
                    if isinstance(getattr(typ, symbol, None), property):
                        continue
                    orig_symbol = getattr(obj, symbol)
                    if callable(orig_symbol):
                        f_name = "%s.%s" % (obj_name, symbol)
                        func = mock_function(f_name, default_ret_val,
                                             record, playback)
                        setattr(obj, symbol, func)
                    else:
                        setattr(obj, symbol, orig_symbol)
                return obj

        return cls_sub

    def create_mock_class(self, cls, name, default_ret_val=None):
        """
        Given something that defines a namespace cls (class, object,
        module), and a (hopefully unique) name, will create a
        mock_class object with that name and that possesses all
        the public attributes of cls. default_ret_val sets the
        default_ret_val on all methods of the cls mock.
        """
        return mock_class(cls, name, default_ret_val,
                          self.__record_call, self.__method_playback)

    def create_mock_function(self, symbol, default_return_val=None):
        """
        Create a mock_function with name symbol and default return
        value of default_return_val.
        """
        return mock_function(symbol, default_return_val,
                             self.__record_call, self.__method_playback)

    def mock_up(self, obj, name, default_ret_val=None):
        """
        Given an object (class instance or module) and a registration
        name, then replace all its methods with mock function objects
        (passing the original functions to the mock functions).
        """
        for symbol in dir(obj):
            if symbol.startswith("__"):
                continue
            orig_symbol = getattr(obj, symbol)
            if callable(orig_symbol):
                f_name = "%s.%s" % (name, symbol)
                func = mask_function(f_name, orig_symbol, default_ret_val,
                                     self.__record_call,
                                     self.__method_playback)
                setattr(obj, symbol, func)

    def stub_with(self, namespace, symbol, new_attribute):
        """Set namespace.symbol to new_attribute, remembering how to undo it."""
        original_attribute = getattr(namespace, symbol,
                                     self.NONEXISTENT_ATTRIBUTE)
        # You only want to save the original attribute in cases where it is
        # directly associated with the object in question. In cases where
        # the attribute is actually inherited via some sort of hierarchy
        # you want to delete the stub (restoring the original structure)
        attribute_is_inherited = (hasattr(namespace, '__dict__') and
                                  symbol not in namespace.__dict__)
        if attribute_is_inherited:
            original_attribute = self.NONEXISTENT_ATTRIBUTE
        newstub = (namespace, symbol, original_attribute, new_attribute)
        self._stubs.append(newstub)
        setattr(namespace, symbol, new_attribute)

    def stub_function(self, namespace, symbol):
        """Replace namespace.symbol with a recording mock_function."""
        mock_attribute = self.create_mock_function(symbol)
        self.stub_with(namespace, symbol, mock_attribute)

    def stub_class_method(self, cls, symbol):
        """Replace cls.symbol with a mock_function wrapped as a staticmethod."""
        mock_attribute = self.create_mock_function(symbol)
        self.stub_with(cls, symbol, staticmethod(mock_attribute))

    def stub_class(self, namespace, symbol):
        """Replace class namespace.symbol with a mock from create_mock_class_obj."""
        attr = getattr(namespace, symbol)
        # named so as not to shadow the module-level mock_class
        mock_cls = self.create_mock_class_obj(attr, symbol)
        self.stub_with(namespace, symbol, mock_cls)

    def stub_function_to_return(self, namespace, symbol, object_to_return):
        """Stub out a function with one that always returns a fixed value.

        @param namespace The namespace containing the function to stub out.
        @param symbol The attribute within the namespace to stub out.
        @param object_to_return The value that the stub should return whenever
                it is called.
        """
        self.stub_with(namespace, symbol,
                       lambda *args, **dargs: object_to_return)

    def _perform_unstub(self, stub):
        """Undo one stub: restore the saved attribute or delete the stub."""
        namespace, symbol, orig_attr, _ = stub
        # identity check: orig_attr's own __eq__ must never be invoked
        if orig_attr is self.NONEXISTENT_ATTRIBUTE:
            delattr(namespace, symbol)
        else:
            setattr(namespace, symbol, orig_attr)

    def unstub(self, namespace, symbol):
        """Undo the most recent stub of namespace.symbol.

        Raises StubNotFoundError if that attribute was never stubbed.
        """
        for stub in reversed(self._stubs):
            # match the exact object that was stubbed, not an equal one
            if stub[0] is namespace and stub[1] == symbol:
                self._perform_unstub(stub)
                self._stubs.remove(stub)
                return
        raise StubNotFoundError()

    def unstub_all(self):
        """Undo every stub, newest first, then forget them all."""
        # LIFO order unwinds repeated stubs of one attribute back to the
        # very first original value.
        for stub in reversed(self._stubs):
            self._perform_unstub(stub)
        self._stubs = []

    def __method_playback(self, symbol, *args, **dargs):
        """Check a real call against the next expected call.

        Returns the expected call's return object (or raises its error)
        when the call matches; otherwise reports an error and returns None.
        """
        if self._debug:
            self._print_to_real_stdout(
                ' * Mock call: ' + _dump_function_call(symbol, args, dargs))

        if not self.recording:
            msg = ("unexpected call: %s"
                   % (_dump_function_call(symbol, args, dargs)))
            self._append_error(msg)
            return None

        func_call = self.recording[0]
        if func_call.symbol != symbol:
            msg = ("Unexpected call: %s\nExpected: %s"
                   % (_dump_function_call(symbol, args, dargs), func_call))
            self._append_error(msg)
            return None
        if not func_call.match(*args, **dargs):
            msg = ("Incorrect call: %s\nExpected: %s"
                   % (_dump_function_call(symbol, args, dargs), func_call))
            self._append_error(msg)
            return None

        # this is the expected call so pop it and return
        self.recording.popleft()
        if func_call.error:
            raise func_call.error
        return func_call.return_obj

    def __record_call(self, mapping):
        """Append an expected call to the playback queue."""
        self.recording.append(mapping)

    def _append_error(self, error):
        """Raise CheckPlaybackError now (fail_fast) or store the error."""
        if self._debug:
            self._print_to_real_stdout(' *** ' + error)
        if self._fail_fast:
            raise CheckPlaybackError(error)
        self.errors.append(error)

    def check_playback(self):
        """
        Report any errors that were encountered during calls
        to __method_playback(), or any expected calls that never happened.

        Calls self._ut.fail() with the report if a ut was given, and
        otherwise raises CheckPlaybackError carrying the same report.
        Clears the recording queue when everything matched.
        """
        if self.errors:
            if self._debug:
                self._print_to_real_stdout('\nPlayback errors:')
            for error in self.errors:
                self._print_to_real_stdout(error)
            report = '\n'.join(self.errors)
            if self._ut:
                self._ut.fail(report)
            raise CheckPlaybackError(report)
        elif self.recording:
            errors = []
            for func_call in self.recording:
                error = "%s not called" % (func_call,)
                errors.append(error)
                self._print_to_real_stdout(error)
            report = '\n'.join(errors)
            if self._ut:
                self._ut.fail(report)
            raise CheckPlaybackError(report)
        self.recording.clear()

    def mock_io(self):
        """Mocks and saves the stdout & stderr output"""
        self.orig_stdout = sys.stdout
        self.orig_stderr = sys.stderr
        self.mock_streams_stdout = StringIO.StringIO('')
        self.mock_streams_stderr = StringIO.StringIO('')
        sys.stdout = self.mock_streams_stdout
        sys.stderr = self.mock_streams_stderr

    def unmock_io(self):
        """Restores the stdout & stderr, and returns both
        output strings (closing the capture buffers)."""
        sys.stdout = self.orig_stdout
        sys.stderr = self.orig_stderr
        values = (self.mock_streams_stdout.getvalue(),
                  self.mock_streams_stderr.getvalue())
        self.mock_streams_stdout.close()
        self.mock_streams_stderr.close()
        return values
def _arg_to_str(arg):
    """Format one argument for a call description.

    Comparators show their human-readable description; every other value
    is shown via repr().
    """
    if not isinstance(arg, argument_comparator):
        return repr(arg)
    return str(arg)
def _dump_function_call(symbol, args, dargs):
    """Render a call as "symbol(arg, ..., key=value, ...)" for messages."""
    rendered = [_arg_to_str(arg) for arg in args]
    rendered.extend("%s=%s" % (key, _arg_to_str(val))
                    for key, val in dargs.items())
    return "%s(%s)" % (symbol, ', '.join(rendered))