#!/usr/bin/env python2.6

# Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import os
import re
import struct
import subprocess
import sys
import tempfile

# TODO(clchiou): Rewrite this part after the official flashmap implementation
# is pulled into the Chromium OS code base.

# Constants imported from lib/fmap.h.
FMAP_SIGNATURE = "__FMAP__"  # written into the header's 'signature' field
# fmap format version.  NOTE(review): not referenced in this file; EntryFmap
# takes ver_major/ver_minor from the config, which can presumably use these.
FMAP_VER_MAJOR = 1
FMAP_VER_MINOR = 0
FMAP_STRLEN = 32  # byte length of the fixed-size 'name' fields

# Area flag bits.  NOTE(review): not referenced in this file; config files run
# by main() share this module's globals, so presumably they combine these into
# an area's 'flags' value -- confirm.
FMAP_AREA_STATIC = 1 << 0
FMAP_AREA_COMPRESSED = 1 << 1
FMAP_AREA_RO = 1 << 2

# struct layouts ('<' = little-endian, standard sizes, no padding) of the fmap
# header and of each area record.  Fields are packed in the order of the
# matching *_NAMES tuple below:
#   header: signature[8], ver_major u8, ver_minor u8, base u64, size u32,
#           name[FMAP_STRLEN], nareas u16
#   area:   offset u32, size u32, name[FMAP_STRLEN], flags u16
FMAP_HEADER_FORMAT = "<8sBBQI%dsH" % (FMAP_STRLEN)
FMAP_AREA_FORMAT = "<II%dsH" % (FMAP_STRLEN)

# Field names for FMAP_HEADER_FORMAT, in packing order.
FMAP_HEADER_NAMES = (
    'signature',
    'ver_major',
    'ver_minor',
    'base',
    'size',
    'name',
    'nareas',
)

# Field names for FMAP_AREA_FORMAT, in packing order.  EntryFmap.Pack fills
# 'size' from an entry's 'length' field.
FMAP_AREA_NAMES = (
    'offset',
    'size',
    'name',
    'flags',
)

# Matches one NAME=VALUE command-line argument; NAME is word characters only.
RE_ASSIGNMENT = re.compile(r'^(\w+)=(.*)$')

# Set to True by main() when '-v' is given; enables _Info() messages and is
# also consulted by EntryKeyBlock.
VERBOSE = False


class ConfigError(Exception):
  """Raised when the packing configuration is missing or malformed."""


class PackError(Exception):
  """Raised when an entry is invalid or cannot be written into the image."""


class Entry(dict):
  """Base class for one region of the firmware image.

  An Entry is a dict of its configuration fields whose values can also be
  read as attributes (entry.offset is entry['offset']).  Every entry needs
  'offset', 'length' and 'name'; subclasses require further fields and
  implement Pack().
  """

  @staticmethod
  def _CheckFields(kwargs, fields):
    """Raise ConfigError for the first name in fields missing from kwargs."""
    for f in fields:
      if f not in kwargs:
        raise ConfigError('Entry: missing required field: %s' % f)

  def __init__(self, **kwargs):
    Entry._CheckFields(kwargs, ('offset', 'length', 'name'))
    super(Entry, self).__init__(kwargs)

  def __getattr__(self, name):
    # Only reached when normal attribute lookup fails.  A missing key must
    # surface as AttributeError (not KeyError) so that hasattr(), getattr()
    # with a default, and protocols such as copy/pickle that probe for
    # optional attributes keep working on entries.
    try:
      return self[name]
    except KeyError:
      raise AttributeError("'%s' object has no attribute '%s'" %
                           (type(self).__name__, name))

  def IsOverlapped(self, entry):
    """Return True if either entry's start offset lies inside the other's
    [offset, offset + length) range."""
    return (entry.offset <= self.offset < entry.offset + entry.length or
            self.offset <= entry.offset < self.offset + self.length)

  def Pack(self, firmware_image, entries):
    """Write this entry into firmware_image; subclasses must override.

    entries is the full list of entries being packed.
    """
    raise PackError('class Entry does not implement Pack()')


class EntryFmap(Entry):
  """Entry that writes the fmap header describing every fmap area.

  Requires 'ver_major', 'ver_minor', 'base' and 'size' in addition to the
  common Entry fields; they are copied into the encoded header.
  """

  def __init__(self, **kwargs):
    Entry._CheckFields(kwargs, ('ver_major', 'ver_minor', 'base', 'size'))
    super(EntryFmap, self).__init__(**kwargs)

  def Pack(self, firmware_image, entries):
    """Encode the fmap and write it at this entry's offset.

    Every EntryFmapArea (including subclasses) in entries becomes one area
    record.  Raises PackError if the encoded fmap exceeds 'length' bytes.
    """
    areas = []
    for entry in entries:
      if not isinstance(entry, EntryFmapArea):
        continue
      record = {}
      for field in FMAP_AREA_NAMES:
        # The record's 'size' comes from the entry's 'length' field.
        record[field] = entry['length'] if field == 'size' else entry[field]
      areas.append(record)

    header = {'areas': areas}
    for field in FMAP_HEADER_NAMES:
      if field == 'nareas':
        header[field] = len(areas)
      elif field == 'signature':
        header[field] = FMAP_SIGNATURE
      else:
        header[field] = self[field]

    blob = fmap_encode(header)
    blob_size = len(blob)
    if blob_size > self.length:
      raise PackError('fmap too large: %d > %d' % (blob_size, self.length))

    firmware_image.seek(self.offset)
    firmware_image.write(blob)


class EntryFmapArea(Entry):
  """Region that gets a record in the fmap but writes no data itself.

  Requires 'flags', which is stored in the area's fmap record.  Subclasses
  that produce data override Pack(); a plain EntryFmapArea may overlap other
  entries (see pack_firmware_image).
  """

  def __init__(self, **kwargs):
    required = ('flags',)
    Entry._CheckFields(kwargs, required)
    super(EntryFmapArea, self).__init__(**kwargs)

  def Pack(self, firmware_image, entries):
    """Do nothing: a plain area only reserves space described by the fmap."""


class EntryWiped(EntryFmapArea):
  """Fmap area filled entirely with one repeated byte.

  'wipe_value' may be an int in [0x00, 0xff] or a one-character str; it is
  normalized to a one-character str and stored back in the entry.  Raises
  PackError for any other value.
  """

  def __init__(self, **kwargs):
    Entry._CheckFields(kwargs, ('wipe_value',))
    super(EntryWiped, self).__init__(**kwargs)
    value = self['wipe_value']
    if type(value) is int:
      # Check the range explicitly instead of relying on chr() to raise:
      # Python 2's chr() rejects values >= 256 but Python 3's does not.
      if not 0 <= value <= 0xff:
        raise PackError('cannot convert wipe_value to a character: '
                        'chr() arg not in range(256)')
      value = chr(value)
    elif type(value) is str:
      if len(value) != 1:
        raise PackError('wipe_value out of range [00:ff]: %s' % repr(value))
    else:
      raise PackError('wipe_value is neither int nor str: %s' % repr(value))
    # Store the normalized value as a dict item so entry['wipe_value'],
    # repr(entry) and entry.wipe_value all agree.
    self['wipe_value'] = value

  def Pack(self, firmware_image, entries):
    """Fill [offset, offset + length) of the image with wipe_value."""
    firmware_image.seek(self.offset)
    firmware_image.write(self.wipe_value * self.length)


class EntryBlob(EntryFmapArea):
  """Fmap area whose contents are copied from the file named by 'path'."""

  def __init__(self, **kwargs):
    Entry._CheckFields(kwargs, ('path',))
    super(EntryBlob, self).__init__(**kwargs)

  def Pack(self, firmware_image, entries):
    """Copy the file at 'path' into the image at this entry's offset.

    Raises PackError if the file is larger than 'length'.  A file whose
    stat() size is 0 (e.g. /dev/zero) is read for 'length' bytes instead.
    """
    blob_size = os.stat(self.path).st_size
    if blob_size > self.length:
      raise PackError('blob too large: %s: %d > %d' %
                      (self.path, blob_size, self.length))
    # Character devices such as /dev/zero report size 0; read a full region.
    read_size = blob_size or self.length
    with open(self.path, 'rb') as source:
      firmware_image.seek(self.offset)
      firmware_image.write(source.read(read_size))


class EntryKeyBlock(EntryFmapArea):
  """Fmap area filled with a key block generated by vbutil_firmware.

  Requires 'keyblock', 'signprivate', 'version', 'fv' and 'kernelkey'; each
  is passed to the vbutil_firmware option of the same name ('version' is
  formatted with %d).
  """

  # Streams given to the vbutil_firmware child process.  None lets the child
  # inherit this process's stdout/stderr; __init__ switches them to
  # sys.stdout/sys.stderr when VERBOSE is set.
  # NOTE(review): inheriting already shows the child's output, so VERBOSE may
  # not change anything here -- confirm whether quiet non-verbose runs were
  # intended.
  stdout = None
  stderr = None

  def __init__(self, **kwargs):
    Entry._CheckFields(kwargs,
        ('keyblock', 'signprivate', 'version', 'fv', 'kernelkey'))
    super(EntryKeyBlock, self).__init__(**kwargs)
    if VERBOSE:
      EntryKeyBlock.stdout = sys.stdout
      EntryKeyBlock.stderr = sys.stderr

  def Pack(self, firmware_image, entries):
    """Generate the key block into a temp file and copy it into the image.

    Raises PackError if vbutil_firmware exits non-zero or its output is
    larger than 'length'.  The temp file is always removed.
    """
    fd, path = tempfile.mkstemp()
    try:
      # vbutil_firmware writes the file by path; close the descriptor that
      # mkstemp() opened so it is not leaked.
      os.close(fd)
      args = [
          'vbutil_firmware',
          '--vblock', path,
          '--keyblock', self.keyblock,
          '--signprivate', self.signprivate,
          '--version', '%d' % self.version,
          '--fv', self.fv,
          '--kernelkey', self.kernelkey,
      ]
      _Info('run: %s' % ' '.join(args))
      returncode = subprocess.call(args,
          stdout=EntryKeyBlock.stdout, stderr=EntryKeyBlock.stderr)
      if returncode != 0:
        raise PackError('cannot make key block: vbutil_firmware returns %d' %
                        returncode)

      size = os.stat(path).st_size
      if size > self.length:
        raise PackError('key block too large: %d > %d' % (size, self.length))

      with open(path, 'rb') as keyblock_image:
        firmware_image.seek(self.offset)
        firmware_image.write(keyblock_image.read())
    finally:
      os.unlink(path)


# TODO(clchiou): Keep the fmap_encode interface compatible with the official
# flashmap implementation, and remove it after that is pulled in.
def fmap_encode(obj):
  """Encode an fmap description into its binary blob.

  obj is a dict holding every FMAP_HEADER_NAMES field except 'nareas', plus
  'areas': a list of dicts holding the FMAP_AREA_NAMES fields.  'nareas' is
  always recomputed from 'areas'.  String fields are packed with struct's
  's' code, so they must be str (bytes on Python 3).  obj is not modified.

  Returns the packed header followed by one packed record per area.
  """
  def _FormatBlob(fmt, names, fields):
    return struct.pack(fmt, *(fields[name] for name in names))

  # Shallow copy so the caller's dict is left untouched.
  header = dict(obj)
  header['nareas'] = len(header['areas'])
  parts = [_FormatBlob(FMAP_HEADER_FORMAT, FMAP_HEADER_NAMES, header)]
  parts.extend(_FormatBlob(FMAP_AREA_FORMAT, FMAP_AREA_NAMES, area)
               for area in header['areas'])
  return b''.join(parts)


def parse_assignment(stmt):
  """Split a NAME=VALUE command-line argument into a (name, value) tuple.

  The value text is converted with parse_value().  Raises ConfigError when
  stmt does not match RE_ASSIGNMENT.
  """
  match = RE_ASSIGNMENT.match(stmt)
  if not match:
    raise ConfigError('illegal statement: %s' % repr(stmt))
  name, raw_value = match.groups()
  return name, parse_value(raw_value)


def parse_value(expr):
  """Convert the VALUE text of a NAME=VALUE argument into a Python value.

  - Text wrapped in matching single or double quotes is returned as the
    string between the quotes, so '"10"' stays the string '10'.
  - Otherwise, text accepted by int(expr, 0) (e.g. '42', '0x10') is
    returned as an int.
  - Anything else is returned unchanged as a string.
  """
  # Require at least two characters so a lone quote is not mistaken for an
  # empty quoted string.
  if len(expr) >= 2 and expr[0] == expr[-1] and expr[0] in ('"', "'"):
    return expr[1:-1]
  try:
    return int(expr, 0)
  except ValueError:
    return expr  # not a number: keep it as a string literal


def pack_firmware_image(entries, output_path, image_size):
  """Write a firmware image of image_size bytes to output_path.

  The file is first filled with NUL bytes.  Then each entry's Pack() is
  called in order of increasing offset, with the full sorted entry list.

  Raises PackError if two entries overlap.  Plain EntryFmapArea entries,
  which write nothing, are exempt and may overlap anything.
  """
  entries = sorted(entries, key=lambda e: e.offset)

  # Only entries that write data need the overlap check.  Filter them first,
  # then compare neighbors: in a list sorted by offset, any overlapping pair
  # implies an overlapping neighboring pair.  Checking neighbors in the full
  # list instead would miss an overlap between two data entries separated by
  # a plain area.  `type(e) is` (not isinstance) exempts only EntryFmapArea
  # itself, never its subclasses; this relies on Entry being a new-style class.
  writers = [e for e in entries if type(e) is not EntryFmapArea]
  for e1, e2 in zip(writers, writers[1:]):
    if e1.IsOverlapped(e2):
      raise PackError('overlapped entries: [%08x:%08x], [%08x:%08x]' %
          (e1.offset, e1.offset + e1.length, e2.offset, e2.offset + e2.length))

  with open(output_path, 'wb') as firmware_image:
    # Size the output file up front by filling it with NUL bytes.
    firmware_image.write(b'\0' * image_size)

    for entry in entries:
      entry.Pack(firmware_image, entries)


def _Info(msg):
  """Write 'INFO: <msg>' as one line to stderr, only when VERBOSE is set."""
  if not VERBOSE:
    return
  sys.stderr.write('INFO: %s\n' % msg)


def main():
  """Command-line entry point: pack a firmware image from a config file.

  Usage: PROG [-v] CONFIG_FILE [NAME=VALUE...]

  '-v', accepted only as the first argument, enables verbose INFO logging.
  Each NAME=VALUE argument is parsed with parse_assignment() and becomes a
  variable visible to CONFIG_FILE.  CONFIG_FILE is run with execfile() and
  must define ENTRIES, OUTPUT and SIZE, which are passed to
  pack_firmware_image().  Raises ConfigError if any of them is missing.
  """
  global VERBOSE

  argv = list(sys.argv)
  if len(argv) > 1 and argv[1] == '-v':
    VERBOSE = True
    del argv[1]

  # Check for CONFIG_FILE after removing '-v', so that "PROG -v" alone prints
  # the usage message instead of failing with IndexError.
  if len(argv) < 2:
    sys.stdout.write('Usage: %s [-v] CONFIG_FILE [NAME=VALUE...]\n' %
                     sys.argv[0])
    sys.exit(1)

  env = dict(parse_assignment(stmt) for stmt in argv[2:])

  # The config file is arbitrary Python run with this module's globals (so it
  # can use the Entry classes and constants); its top-level assignments land
  # in env.  Only run trusted config files.
  execfile(argv[1], globals(), env)

  for varname in ('ENTRIES', 'OUTPUT', 'SIZE'):
    if varname not in env:
      raise ConfigError('undefined variable: %s' % varname)
    _Info('%s = %s' % (varname, repr(env[varname])))

  pack_firmware_image(env['ENTRIES'], env['OUTPUT'], env['SIZE'])

  sys.exit(0)


# Run main() only when executed as a script, not when imported.
if __name__ == '__main__':
  main()
