# blob: 26e5c67bf47c72bdbe46cf16a7ab9cc3f4284ddf [file] [log] [blame]
from .core import Adapter, AdaptationError, Pass
from .lib import int_to_bin, bin_to_int, swap_bytes
from .lib import FlagsContainer, HexString
from .lib.py3compat import BytesIO, decodebytes
#===============================================================================
# exceptions
#===============================================================================
class BitIntegerError(AdaptationError):
    """Raised by BitIntegerAdapter when a negative value is given for an
    unsigned field."""

    __slots__ = []
class MappingError(AdaptationError):
    """Raised by MappingAdapter when an object has no mapping and no default
    was configured."""

    __slots__ = []
class ConstError(AdaptationError):
    """Raised by ConstAdapter when a value differs from the expected
    constant."""

    __slots__ = []
class ValidationError(AdaptationError):
    """Raised by Validator when an object fails its validation check."""

    __slots__ = []
class PaddingError(AdaptationError):
    """Raised by PaddingAdapter (strict mode) when parsed padding does not
    match the expected pattern."""

    __slots__ = []
#===============================================================================
# adapters
#===============================================================================
class BitIntegerAdapter(Adapter):
    """
    Adapter for bit-integers: building turns an integer into a bitstring,
    parsing turns a bitstring back into an integer.
    See BitField.

    Parameters:
    * subcon - the subcon to adapt
    * width - the size of the subcon, in bits
    * swapped - whether to swap byte order (little endian/big endian).
      default is False (big endian)
    * signed - whether the value is signed (two's complement). the default
      is False (unsigned)
    * bytesize - number of bits per byte, used for byte-swapping (if swapped).
      default is 8.
    """
    __slots__ = ["width", "swapped", "signed", "bytesize"]

    def __init__(self, subcon, width, swapped = False, signed = False,
                 bytesize = 8):
        Adapter.__init__(self, subcon)
        self.width = width
        self.swapped = swapped
        self.signed = signed
        self.bytesize = bytesize

    def _encode(self, obj, context):
        # negative numbers can only be represented by signed fields
        if obj < 0:
            if not self.signed:
                raise BitIntegerError(
                    "object is negative, but field is not signed", obj)
        bits = int_to_bin(obj, width = self.width)
        if not self.swapped:
            return bits
        return swap_bytes(bits, bytesize = self.bytesize)

    def _decode(self, obj, context):
        # undo the byte swap (if any) before interpreting the bits
        bits = swap_bytes(obj, bytesize = self.bytesize) if self.swapped else obj
        return bin_to_int(bits, signed = self.signed)
class MappingAdapter(Adapter):
    """
    Adapter that translates objects through lookup tables, one for parsing
    and one for building.
    See SymmetricMapping and Enum.

    Parameters:
    * subcon - the subcon to map
    * decoding - the decoding (parsing) mapping (a dict)
    * encoding - the encoding (building) mapping (a dict)
    * decdefault - the value returned when the object is not found in the
      decoding mapping. if left unset (NotImplemented), a MappingError is
      raised. if `Pass` is used, the unmapped object is returned as-is
    * encdefault - the value returned when the object is not found in the
      encoding mapping. if left unset (NotImplemented), a MappingError is
      raised. if `Pass` is used, the unmapped object is returned as-is
    """
    __slots__ = ["encoding", "decoding", "encdefault", "decdefault"]

    def __init__(self, subcon, decoding, encoding,
                 decdefault = NotImplemented, encdefault = NotImplemented):
        Adapter.__init__(self, subcon)
        self.decoding = decoding
        self.encoding = encoding
        self.decdefault = decdefault
        self.encdefault = encdefault

    def _encode(self, obj, context):
        try:
            mapped = self.encoding[obj]
        except (KeyError, TypeError):
            # TypeError covers unhashable objects used as keys
            fallback = self.encdefault
            if fallback is NotImplemented:
                raise MappingError("no encoding mapping for %r [%s]" % (
                    obj, self.subcon.name))
            mapped = obj if fallback is Pass else fallback
        return mapped

    def _decode(self, obj, context):
        try:
            mapped = self.decoding[obj]
        except (KeyError, TypeError):
            # TypeError covers unhashable objects used as keys
            fallback = self.decdefault
            if fallback is NotImplemented:
                raise MappingError("no decoding mapping for %r [%s]" % (
                    obj, self.subcon.name))
            mapped = obj if fallback is Pass else fallback
        return mapped
class FlagsAdapter(Adapter):
    """
    Adapter for flag fields. Parsing tests every known flag against the
    number and returns a FlagsContainer of booleans; building ORs the values
    of the set flags back into a number. Not intended for direct usage.
    See FlagsEnum.

    Parameters:
    * subcon - the subcon to extract
    * flags - a dictionary mapping flag-names to their value
    """
    __slots__ = ["flags"]

    def __init__(self, subcon, flags):
        Adapter.__init__(self, subcon)
        self.flags = flags

    def _encode(self, obj, context):
        combined = 0
        for flag_name, mask in self.flags.items():
            # attributes missing from obj count as unset flags
            if not getattr(obj, flag_name, False):
                continue
            combined |= mask
        return combined

    def _decode(self, obj, context):
        container = FlagsContainer()
        for flag_name, mask in self.flags.items():
            is_set = bool(obj & mask)
            setattr(container, flag_name, is_set)
        return container
class StringAdapter(Adapter):
    """
    Adapter for strings. When an encoding is configured, building calls
    obj.encode(encoding) and parsing calls obj.decode(encoding); otherwise
    the object passes through untouched.
    See String.

    Parameters:
    * subcon - the subcon to convert
    * encoding - the character encoding name (e.g., "utf8"), or None to
      return raw bytes (usually 8-bit ASCII).
    """
    __slots__ = ["encoding"]

    def __init__(self, subcon, encoding = None):
        Adapter.__init__(self, subcon)
        self.encoding = encoding

    def _encode(self, obj, context):
        if not self.encoding:
            return obj
        return obj.encode(self.encoding)

    def _decode(self, obj, context):
        if not self.encoding:
            return obj
        return obj.decode(self.encoding)
class PaddedStringAdapter(Adapter):
    r"""
    Adapter for padded strings. Parsing strips the padding character from the
    padded side(s); building pads the string up to the subcon's size and
    trims it if it is still too long.
    See String.

    Parameters:
    * subcon - the subcon to adapt
    * padchar - the padding character. default is "\x00".
    * paddir - the direction where padding is placed ("right", "left", or
      "center"). the default is "right".
    * trimdir - the direction where trimming will take place ("right" or
      "left"). the default is "right". trimming is only meaningful for
      building, when the given string is too long.

    Raises ValueError if paddir or trimdir is not one of the allowed values.
    """
    __slots__ = ["padchar", "paddir", "trimdir"]

    def __init__(self, subcon, padchar = "\x00", paddir = "right",
                 trimdir = "right"):
        # validate the directions before touching the base class
        if paddir not in ("right", "left", "center"):
            raise ValueError("paddir must be 'right', 'left' or 'center'",
                paddir)
        if trimdir not in ("right", "left"):
            raise ValueError("trimdir must be 'right' or 'left'", trimdir)
        Adapter.__init__(self, subcon)
        self.padchar = padchar
        self.paddir = paddir
        self.trimdir = trimdir

    def _decode(self, obj, context):
        direction = self.paddir
        if direction == "right":
            return obj.rstrip(self.padchar)
        if direction == "left":
            return obj.lstrip(self.padchar)
        # anything else is treated as "center": strip both sides
        return obj.strip(self.padchar)

    def _encode(self, obj, context):
        size = self._sizeof(context)
        direction = self.paddir
        if direction == "right":
            padded = obj.ljust(size, self.padchar)
        elif direction == "left":
            padded = obj.rjust(size, self.padchar)
        else:
            padded = obj.center(size, self.padchar)
        # padding never shortens; an over-long string must be trimmed
        if len(padded) <= size:
            return padded
        if self.trimdir == "right":
            return padded[:size]
        return padded[-size:]
class LengthValueAdapter(Adapter):
    """
    Adapter for length-value pairs. Parsing keeps only the value (the second
    item of the pair); building recomputes the length from the value with
    len().
    See PrefixedArray and PascalString.

    Parameters:
    * subcon - the subcon returning a length-value pair
    """
    __slots__ = []

    def _encode(self, obj, context):
        length = len(obj)
        return (length, obj)

    def _decode(self, obj, context):
        value = obj[1]
        return value
class CStringAdapter(StringAdapter):
    r"""
    Adapter for C-style strings (strings terminated by a terminator char).
    Building appends the first terminator after the encoded string; parsing
    drops the last parsed element and joins the rest before decoding.

    Parameters:
    * subcon - the subcon to convert
    * terminators - a sequence of terminator chars. default is "\x00".
    * encoding - the character encoding to use (e.g., "utf8"), or None to
      return raw-bytes. the terminator characters are not affected by the
      encoding.
    """
    __slots__ = ["terminators"]

    def __init__(self, subcon, terminators = b"\x00", encoding = None):
        StringAdapter.__init__(self, subcon, encoding = encoding)
        self.terminators = terminators

    def _encode(self, obj, context):
        encoded = StringAdapter._encode(self, obj, context)
        # slicing (not indexing) keeps the terminator a bytes object
        terminator = self.terminators[0:1]
        return encoded + terminator

    def _decode(self, obj, context):
        # the final element is the terminator itself; leave it out
        body = b''.join(obj[:-1])
        return StringAdapter._decode(self, body, context)
class TunnelAdapter(Adapter):
    """
    Adapter for tunneling (as in protocol tunneling). A tunnel is a construct
    nested upon another (layering). For parsing, the lower layer first parses
    the data (note: it must return a string!), then the upper layer is called
    to parse that data (bottom-up). For building it works in a top-down
    manner; first the upper layer builds the data, then the lower layer takes
    it and writes it to the stream.

    Parameters:
    * subcon - the lower layer subcon
    * inner_subcon - the upper layer (tunneled/nested) subcon

    Example:
    # a pascal string containing compressed data (zlib encoding), so first
    # the string is read, decompressed, and finally re-parsed as an array
    # of UBInt16
    TunnelAdapter(
        PascalString("data", encoding = "zlib"),
        GreedyRange(UBInt16("elements"))
    )
    """
    __slots__ = ["inner_subcon"]

    def __init__(self, subcon, inner_subcon):
        Adapter.__init__(self, subcon)
        self.inner_subcon = inner_subcon

    def _decode(self, obj, context):
        # re-parse the lower layer's output as an in-memory stream
        parse_inner = self.inner_subcon._parse
        return parse_inner(BytesIO(obj), context)

    def _encode(self, obj, context):
        # build the upper layer into a buffer; the lower layer writes it out
        buf = BytesIO()
        self.inner_subcon._build(obj, buf, context)
        return buf.getvalue()
class ExprAdapter(Adapter):
    """
    A generic adapter whose encoding and decoding steps are supplied as
    callables. Use it instead of writing a full-blown Adapter subclass when
    only a simple expression is needed.

    Parameters:
    * subcon - the subcon to adapt
    * encoder - a function that takes (obj, context) and returns an encoded
      version of obj
    * decoder - a function that takes (obj, context) and returns a decoded
      version of obj

    Example:
    ExprAdapter(UBInt8("foo"),
        encoder = lambda obj, ctx: obj // 4,
        decoder = lambda obj, ctx: obj * 4,
    )
    """
    __slots__ = ["_encode", "_decode"]

    def __init__(self, subcon, encoder, decoder):
        Adapter.__init__(self, subcon)
        # the callables are stored per instance under the hook names
        self._encode, self._decode = encoder, decoder
class HexDumpAdapter(Adapter):
    """
    Adapter for hex-dumping strings. Parsing wraps the parsed object in a
    HexString; building passes the object through unchanged.

    Parameters:
    * subcon - the subcon to hex-dump
    * linesize - forwarded to HexString as its linesize argument.
      default is 16.
    """
    __slots__ = ["linesize"]

    def __init__(self, subcon, linesize = 16):
        Adapter.__init__(self, subcon)
        self.linesize = linesize

    def _encode(self, obj, context):
        # nothing to undo when building
        return obj

    def _decode(self, obj, context):
        line_width = self.linesize
        return HexString(obj, linesize = line_width)
class ConstAdapter(Adapter):
    """
    Adapter for enforcing a constant value ("magic numbers"). When parsing,
    the parsed value must equal the constant; when building, the constant is
    substituted in (the given object must be None or equal to it).

    Parameters:
    * subcon - the subcon to validate
    * value - the expected value

    Raises ConstError when the object does not match the expected value.

    Example:
    Const(Field("signature", 2), "MZ")
    """
    __slots__ = ["value"]

    def __init__(self, subcon, value):
        Adapter.__init__(self, subcon)
        self.value = value

    def _encode(self, obj, context):
        # None means "fill in the constant for me"
        acceptable = obj is None or obj == self.value
        if not acceptable:
            raise ConstError("expected %r, found %r" % (self.value, obj))
        return self.value

    def _decode(self, obj, context):
        expected = self.value
        if obj != expected:
            raise ConstError("expected %r, found %r" % (expected, obj))
        return obj
class SlicingAdapter(Adapter):
    """
    Adapter for slicing a list (getting a slice from that list).
    Parsing returns obj[start:stop:step]; building pads the given list with
    None so that the same slice would select the given items again.

    Parameters:
    * subcon - the subcon to slice
    * start - start index (or None to start at the beginning)
    * stop - stop index (or None for up-to-end)
    * step - step (or None for every element). must be None or a positive
      integer, since building has to invert the slice.

    Raises ValueError if step is given and is smaller than 1.
    """
    __slots__ = ["start", "stop", "step"]

    def __init__(self, subcon, start, stop = None, step = None):
        if step is not None and step < 1:
            raise ValueError("step must be None or a positive integer", step)
        Adapter.__init__(self, subcon)
        self.start = start
        self.stop = stop
        # previously declared in __slots__ but never assigned, so reading
        # it raised AttributeError
        self.step = step

    def _encode(self, obj, context):
        step = self.step
        if step is None or step == 1:
            # contiguous slice: only the leading elements need filling
            if self.start is None:
                return obj
            return [None] * self.start + obj
        # strided slice: spread the items `step` apart after the offset.
        # a negative start adds no leading filler, same as the contiguous
        # case ([None] * negative is empty)
        offset = max(self.start or 0, 0)
        items = list(obj)
        if not items:
            return [None] * offset
        padded = [None] * (offset + (len(items) - 1) * step + 1)
        padded[offset::step] = items
        return padded

    def _decode(self, obj, context):
        return obj[self.start:self.stop:self.step]
class IndexingAdapter(Adapter):
    """
    Adapter for indexing a list (getting a single item from that list).
    Parsing returns obj[index]; building returns a list holding the object
    at that index, preceded by `index` None placeholders.

    Parameters:
    * subcon - the subcon to index
    * index - the index of the list to get

    Raises TypeError if index is not an integer.
    """
    __slots__ = ["index"]

    def __init__(self, subcon, index):
        Adapter.__init__(self, subcon)
        # isinstance (rather than an exact type match) also accepts int
        # subclasses, which index and multiply lists just like int
        if not isinstance(index, int):
            raise TypeError("index must be an integer", type(index))
        self.index = index

    def _encode(self, obj, context):
        placeholders = [None] * self.index
        return placeholders + [obj]

    def _decode(self, obj, context):
        return obj[self.index]
class PaddingAdapter(Adapter):
    r"""
    Adapter for padding. Building always emits the pattern repeated to the
    subcon's size, ignoring the given object; parsing returns the data and,
    in strict mode, first checks it against that same expected padding.

    Parameters:
    * subcon - the subcon to pad
    * pattern - the padding pattern (character). default is "\x00"
    * strict - whether or not to verify, during parsing, that the given
      padding matches the padding pattern. default is False (unstrict)

    Raises PaddingError in strict mode when the parsed padding differs.
    """
    __slots__ = ["pattern", "strict"]

    def __init__(self, subcon, pattern = "\x00", strict = False):
        Adapter.__init__(self, subcon)
        self.pattern = pattern
        self.strict = strict

    def _encode(self, obj, context):
        length = self._sizeof(context)
        return length * self.pattern

    def _decode(self, obj, context):
        if not self.strict:
            return obj
        expected = self._sizeof(context) * self.pattern
        if obj != expected:
            raise PaddingError("expected %r, found %r" % (expected, obj))
        return obj
#===============================================================================
# validators
#===============================================================================
class Validator(Adapter):
    """
    Abstract class: checks a condition on the object in both directions
    (parsing and building) and returns the object unchanged when it holds.
    Override _validate(obj, context) in deriving classes.

    Parameters:
    * subcon - the subcon to validate

    Raises ValidationError when _validate returns a false value.
    """
    __slots__ = []

    def _decode(self, obj, context):
        if self._validate(obj, context):
            return obj
        raise ValidationError("invalid object", obj)

    def _encode(self, obj, context):
        # building applies exactly the same check as parsing
        return self._decode(obj, context)

    def _validate(self, obj, context):
        raise NotImplementedError()
class OneOf(Validator):
    """
    Validates that the object is a member of the given collection of valid
    values.

    :param ``Construct`` subcon: object to validate
    :param iterable valids: a set of valid values

    >>> OneOf(UBInt8("foo"), [4,5,6,7]).parse("\\x05")
    5
    >>> OneOf(UBInt8("foo"), [4,5,6,7]).parse("\\x08")
    Traceback (most recent call last):
        ...
    construct.core.ValidationError: ('invalid object', 8)
    >>>
    >>> OneOf(UBInt8("foo"), [4,5,6,7]).build(5)
    '\\x05'
    >>> OneOf(UBInt8("foo"), [4,5,6,7]).build(9)
    Traceback (most recent call last):
        ...
    construct.core.ValidationError: ('invalid object', 9)
    """
    __slots__ = ["valids"]

    def __init__(self, subcon, valids):
        Validator.__init__(self, subcon)
        self.valids = valids

    def _validate(self, obj, context):
        allowed = self.valids
        return obj in allowed
class NoneOf(Validator):
    """
    Validates that the object is not a member of the given collection of
    invalid values.

    :param ``Construct`` subcon: object to validate
    :param iterable invalids: a set of invalid values

    >>> NoneOf(UBInt8("foo"), [4,5,6,7]).parse("\\x08")
    8
    >>> NoneOf(UBInt8("foo"), [4,5,6,7]).parse("\\x06")
    Traceback (most recent call last):
        ...
    construct.core.ValidationError: ('invalid object', 6)
    """
    __slots__ = ["invalids"]

    def __init__(self, subcon, invalids):
        Validator.__init__(self, subcon)
        self.invalids = invalids

    def _validate(self, obj, context):
        forbidden = self.invalids
        return obj not in forbidden