#!/usr/bin/python
# Copyright (c) 2011 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Support generic spreadsheet-like table information."""

import inspect
import re
import sys

from chromite.lib import cros_build_lib


class Table(object):
  """Class to represent column headers and rows of data.

  Columns are kept in order; each row is stored as a dict mapping column
  name to cell value.
  """

  __slots__ = ['_column_set',  # Set of column headers (for faster lookup)
               '_columns',     # List of column headers in order
               '_name',        # Name to associate with table
               '_rows',        # List of row dicts
               ]

  EMPTY_CELL = ''

  # Former internal quote-boundary markers for CSV parsing.  _SplitCSVLine no
  # longer uses them; they remain as public class attributes in case other
  # code references them.
  CSV_BQ = '__BEGINQUOTE__'
  CSV_EQ = '__ENDQUOTE__'

  @staticmethod
  def _SplitCSVLine(line):
    r'''Split a single CSV line into separate values.

    Behavior illustrated by the following examples.  The first eight mirror
    Google Docs spreadsheet behavior; the rest document this parser's own
    handling of escaped commas, stray quotes and unterminated quotes:
    'a,b,c,d':           ==> ['a', 'b', 'c', 'd'],
    'a, b, c, d':        ==> ['a', ' b', ' c', ' d'],
    'a,b,c,':            ==> ['a', 'b', 'c', ''],
    'a,"b c",d':         ==> ['a', 'b c', 'd'],
    'a,"b, c",d':        ==> ['a', 'b, c', 'd'],
    'a,"b, c, d",e':     ==> ['a', 'b, c, d', 'e'],
    'a,"""b, c""",d':    ==> ['a', '"b, c"', 'd'],
    'a,"""b, c"", d",e': ==> ['a', '"b, c", d', 'e'],
    'a,b\,c,d':          ==> ['a', 'b,c', 'd'],
    'a,5",c':            ==> ['a', '5"', 'c'],
    'a,"b,,c",d':        ==> ['a', 'b,,c', 'd'],
    'a,"b,c':            ==> ['a', '"b,c'],

    Rules:
    - A backslash-escaped comma (\,) is never a separator; it becomes ','.
    - A value is quoted only when it begins with a double quote.  Commas
      inside it are part of the value.  It ends at the first comma-separated
      segment that ends in an odd number of quotes, since "" inside quotes is
      an escaped quote.
    - "" is turned into " both inside and outside quotes (Google Spreadsheet
      syntax).
    - If a quoted value is never closed, its opening quote is kept.

    Return a list of values.'''
    def _Unescape(text):
      """Undo backslash-comma and doubled-quote escapes."""
      text = text.replace(r'\,', ',')  # \ before ,
      return text.replace('""', '"')   # " before " (Google Spreadsheet syntax)

    vals = []
    # True while a quoted value has been opened but not yet closed, meaning the
    # next comma-separated segment belongs to that same value.
    in_quotes = False
    for segment in re.split(r'(?<!\\),', line):
      continuing = in_quotes
      if not continuing:
        if not segment.startswith('"'):
          # Plain unquoted value; a trailing quote here is just data.
          vals.append(_Unescape(segment))
          continue
        segment = segment[1:]  # Drop the opening quote.
        in_quotes = True

      # Inside quotes, an odd-length run of trailing quotes means the last
      # one is the closing quote (the others pair up as escaped quotes).
      num_trailing = len(segment) - len(segment.rstrip('"'))
      if num_trailing % 2:
        segment = segment[:-1]
        in_quotes = False
      segment = _Unescape(segment)

      if continuing:
        # The comma that split this segment off was part of the value.
        vals[-1] += ',' + segment
      else:
        vals.append(segment)

    if in_quotes:
      # Unterminated quoted value: restore its opening quote literally.
      vals[-1] = '"' + vals[-1]

    return vals

  @staticmethod
  def LoadFromCSV(csv_file, name=None):
    """Create a new Table object by loading contents of |csv_file|.

    Args:
      csv_file: Either a path to open, or an already-open file-like object
        (anything with a read() method).  Only a file opened here is closed
        here.
      name: Optional name to associate with the new table.

    Returns:
      A Table whose columns come from the first line, with one row per
      following line.  Returns None if the input has no lines.
    """
    opened_here = not hasattr(csv_file, 'read')
    file_handle = open(csv_file, 'r') if opened_here else csv_file
    try:
      table = None
      for line in file_handle:
        # Strip the line terminator, tolerating CRLF line endings.
        if line.endswith('\n'):
          line = line[:-1]
        if line.endswith('\r'):
          line = line[:-1]

        vals = Table._SplitCSVLine(line)

        # Test against None explicitly: a table's truthiness must not decide
        # whether the header has been read yet.
        if table is None:
          # Read headers
          table = Table(vals, name=name)
        else:
          # Read data row
          table.AppendRow(vals)

      return table
    finally:
      if opened_here:
        file_handle.close()

  def __init__(self, columns, name=None):
    # Copy |columns| so InsertColumn never mutates the caller's list, or any
    # other table built from the same list.
    self._columns = list(columns)
    self._column_set = set(self._columns)
    self._rows = []
    self._name = name

  def __str__(self):
    """Return a table-like string representation of this table."""
    header = ', '.join('%10s' % col for col in self._columns)
    lines = ['Columns: %s' % header]
    for ix, row in enumerate(self._rows):
      vals = ', '.join('%10s' % row[col] for col in self._columns)
      lines.append('Row %3d: %s' % (ix, vals))
    return '\n'.join(lines) + '\n'

  def __nonzero__(self):
    """Define boolean equivalent for this table: True if it has columns."""
    return bool(self._columns)

  # Python 3 spelling of __nonzero__.  Without it, truthiness would fall back
  # to __len__, which is the row count.
  __bool__ = __nonzero__

  def __len__(self):
    """Length of table equals the number of rows."""
    return self.GetNumRows()

  def __eq__(self, other):
    """Return true if two tables have equal columns and rows."""
    if not isinstance(other, Table):
      return NotImplemented
    # pylint: disable=W0212
    return self._columns == other._columns and self._rows == other._rows

  def __ne__(self, other):
    """Return true if two tables are not equal."""
    return not self == other

  def __getitem__(self, index):
    """Access one or more rows by index or slice."""
    return self.GetRowByIndex(index)

  def __delitem__(self, index):
    """Delete one or more rows by index or slice."""
    self.RemoveRowByIndex(index)

  def __iter__(self):
    """Declare that this class supports iteration (over rows)."""
    return iter(self._rows)

  def GetName(self):
    """Return name associated with table, None if not available."""
    return self._name

  def SetName(self, name):
    """Set the name associated with table."""
    self._name = name

  def Clear(self):
    """Remove all row data."""
    self._rows = []

  def GetNumRows(self):
    """Return the number of rows in the table."""
    return len(self._rows)

  def GetNumColumns(self):
    """Return the number of columns in the table."""
    return len(self._columns)

  def GetColumns(self):
    """Return a copy of the list of column names in order."""
    return list(self._columns)

  def GetRowByIndex(self, index):
    """Access one or more rows by index or slice.

    If more than one row is returned they will be contained in a list."""
    return self._rows[index]

  def _GenRowFilter(self, id_values):
    """Return a function that returns true for rows matching |id_values|.

    A row matches when, for every key in |id_values|, row.get(key) equals the
    given value.
    """
    def Grep(row):
      """Filter function for rows with id_values."""
      return all(id_values[key] == row.get(key, None) for key in id_values)
    return Grep

  def GetRowsByValue(self, id_values):
    """Return list of rows matching key/value pairs in |id_values|."""
    # If row retrieval by value is heavily used for larger tables, then
    # the implementation should change to be more efficient, at the
    # expense of some pre-processing and extra storage.
    grep = self._GenRowFilter(id_values)
    return [r for r in self._rows if grep(r)]

  def GetRowIndicesByValue(self, id_values):
    """Return list of indices for rows matching k/v pairs in |id_values|."""
    grep = self._GenRowFilter(id_values)
    return [ix for ix, row in enumerate(self._rows) if grep(row)]

  def _PrepareValuesForAdd(self, values):
    """Prepare a |values| dict/list/tuple to be added as a row.

    If |values| is a dict, verify that only supported column
    values are included. Add empty string values for columns
    not seen in the row.  The original dict may be altered.

    If |values| is a list or tuple, translate it to a dict using known
    column order, padding with empty values as needed to match the number
    of expected columns.  The original sequence is not modified.

    Any other value is returned unchanged.

    Raises:
      LookupError: if a dict names an unknown column, or a sequence has more
        values than there are columns.

    Return prepared dict.
    """
    if isinstance(values, dict):
      for col in values:
        if col not in self._column_set:
          raise LookupError("Tried adding data to unknown column '%s'" % col)

      for col in self._columns:
        if col not in values:
          values[col] = self.EMPTY_CELL

    elif isinstance(values, (list, tuple)):
      if len(values) > len(self._columns):
        raise LookupError("Tried adding row with too many columns")
      shortage = len(self._columns) - len(values)
      padded = list(values) + [self.EMPTY_CELL] * shortage
      values = dict(zip(self._columns, padded))

    return values

  def AppendRow(self, values):
    """Add a single row of data to the table, according to |values|.

    The |values| argument can be a dict, list or tuple."""
    row = self._PrepareValuesForAdd(values)
    self._rows.append(row)

  def SetRowByIndex(self, index, values):
    """Replace the row at |index| with |values| (a dict, list or tuple)."""
    row = self._PrepareValuesForAdd(values)
    self._rows[index] = row

  def RemoveRowByIndex(self, index):
    """Remove the row at |index|."""
    del self._rows[index]

  def HasColumn(self, name):
    """Return True if column |name| is in this table, False otherwise."""
    return name in self._column_set

  def GetColumnIndex(self, name):
    """Return the column index for column |name|, -1 if not found."""
    try:
      return self._columns.index(name)
    except ValueError:
      return -1

  def GetColumnByIndex(self, index):
    """Return the column name at |index|"""
    return self._columns[index]

  def InsertColumn(self, index, name, value=None):
    """Insert a new column |name| into table at index |index|.

    If |value| is specified, all rows will have |value| in the new column.
    Otherwise, they will have the EMPTY_CELL value.

    Raises:
      LookupError: if the column already exists.
    """
    if self.HasColumn(name):
      raise LookupError("Column %s already exists in table." % name)

    self._columns.insert(index, name)
    self._column_set.add(name)

    fill = value if value is not None else self.EMPTY_CELL
    for row in self._rows:
      row[name] = fill

  def AppendColumn(self, name, value=None):
    """Same as InsertColumn, but new column is appended after existing ones."""
    self.InsertColumn(self.GetNumColumns(), name, value)

  def ProcessRows(self, row_processor):
    """Invoke |row_processor| on each row in sequence."""
    for row in self._rows:
      row_processor(row)

  def MergeTable(self, other_table, id_columns, merge_rules=None,
                 allow_new_columns=False, key=None, reverse=False,
                 new_name=None):
    """Merge |other_table| into this table, identifying rows by |id_columns|.

    The |id_columns| argument can either be a list of identifying columns names
    or a single column name (string).  The values in these columns will be used
    to identify the existing row that each row in |other_table| should be
    merged into.  Rows with no match are appended.

    The |merge_rules| specify what to do when there is a merge conflict.  Every
    column where a conflict is anticipated should have an entry in the
    |merge_rules| dict.  The value should be one of:
    'join_with:<text>' = Join the two conflicting values with <text>
    'accept_this_val' = Keep value in 'this' table and discard 'other' value.
    'accept_other_val' = Keep value in 'other' table and discard 'this' value.
    callable = Keep return value from callable(col_name, this_val, other_val)

    A default merge rule can be specified with the key '__DEFAULT__' in
    |merge_rules|.  The caller's |merge_rules| dict is not modified.

    By default, the |other_table| must not have any columns that don't already
    exist in this table.  To allow new columns to be created by virtue of their
    presence in |other_table|, set |allow_new_columns| to true.

    To sort the final merged table, supply |key| and |reverse| arguments exactly
    as they work with the Sort method.

    If |new_name| is given it becomes this table's name; otherwise, if both
    tables are named, the name becomes '<this name> + <other name>'.
    """
    # If requested, allow columns in other_table to create new columns
    # in this table if this table does not already have them.
    if allow_new_columns:
      # Work on a copy so the caller's |merge_rules| dict is not modified.
      merge_rules = dict(merge_rules or {})
      # pylint: disable=W0212
      for ix, col in enumerate(other_table._columns):
        if self.HasColumn(col):
          continue
        # Create a merge_rule on the fly for this new column.
        merge_rules[col] = 'accept_other_val'

        if ix == 0:
          self.InsertColumn(0, col)
        else:
          # Place the column right after its predecessor in other_table.  That
          # predecessor exists here by now because columns are visited in order.
          previx = self.GetColumnIndex(other_table._columns[ix - 1])
          self.InsertColumn(previx + 1, col)

    for other_row in other_table:
      self._MergeRow(other_row, id_columns, merge_rules=merge_rules)

    # Optionally re-sort the merged table.
    if key:
      self.Sort(key, reverse=reverse)

    if new_name:
      self.SetName(new_name)
    elif self.GetName() and other_table.GetName():
      self.SetName(self.GetName() + ' + ' + other_table.GetName())

  def _GetIdValuesForRow(self, row, id_columns):
    """Return a dict with values from |row| in |id_columns|."""
    id_values = dict((col, row[col]) for col in
                     cros_build_lib.iflatten_instance(id_columns))
    return id_values

  def _MergeRow(self, other_row, id_columns, merge_rules=None):
    """Merge |other_row| into this table.

    Only the first row matching |other_row|'s id values is merged into; if
    none match, a copy of |other_row| is appended.  See MergeTable for a
    description of |id_columns| and |merge_rules|.

    Raises:
      LookupError: if |other_row| has a column this table lacks.
      ValueError: if a conflicting value cannot be merged.
    """
    id_values = self._GetIdValuesForRow(other_row, id_columns)

    row_indices = self.GetRowIndicesByValue(id_values)
    if not row_indices:
      # Append a copy so this table never shares a row dict with other_table.
      self.AppendRow(dict(other_row))
      return

    row_index = row_indices[0]
    row = self.GetRowByIndex(row_index)
    for col in other_row:
      if col not in row:
        # Cannot add new columns to row this way.
        raise LookupError("Tried merging data to unknown column '%s'" % col)

      # Find the merge rule that applies to this column, if there is one.
      merge_rule = None
      if merge_rules:
        merge_rule = (merge_rules.get(col, None) or
                      merge_rules.get('__DEFAULT__', None))

      try:
        val = self._MergeColValue(col, row[col], other_row[col],
                                  merge_rule=merge_rule)
      except ValueError:
        msg = "Failed to merge '%s' value in row %r" % (col, id_values)
        sys.stderr.write(msg + '\n')
        raise

      row[col] = val
    self.SetRowByIndex(row_index, row)

  def _MergeColValue(self, col, val, other_val, merge_rule):
    """Merge |col| values |val| and |other_val| according to |merge_rule|.

    See MergeTable method for explanation of option |merge_rule|.

    Raises:
      ValueError: if the values differ and |merge_rule| is missing, unknown,
        or is a callable that raised ValueError.
    """
    if val == other_val:
      return val

    if not merge_rule:
      raise ValueError("Cannot merge column values without rule: '%s' vs '%s'" %
                       (val, other_val))
    elif callable(merge_rule):
      try:
        return merge_rule(col, val, other_val)
      except ValueError:
        pass  # Fall through to exception at end
    elif merge_rule == 'accept_this_val':
      return val
    elif merge_rule == 'accept_other_val':
      return other_val
    else:
      match = re.match(r'join_with:(.+)$', merge_rule)
      if match:
        # Empty values are skipped so no dangling separator is produced.
        return match.group(1).join(v for v in (val, other_val) if v)

    raise ValueError("Invalid merge rule (%s) for values '%s' and '%s'." %
                     (merge_rule, val, other_val))

  def Sort(self, key, reverse=False):
    """Sort the rows using the given |key| function."""
    self._rows.sort(key=key, reverse=reverse)

  @staticmethod
  def _QuoteCSVValue(val):
    """Return |val| in a form that _SplitCSVLine reads back unchanged.

    A value is wrapped in double quotes, with embedded quotes doubled, if it
    contains a comma or a double quote, or if it ends in a backslash (which
    would otherwise escape the following separator).  A literal
    backslash-comma pair inside a value cannot be represented; it is read
    back as a plain comma.
    """
    if ',' in val or '"' in val or val.endswith('\\'):
      return '"%s"' % val.replace('"', '""')
    return val

  def WriteCSV(self, filehandle, hiddencols=None):
    """Write this table out as comma-separated values to |filehandle|.

    To skip certain columns during the write, use the |hiddencols| set.
    Values are quoted where needed so LoadFromCSV can read them back.
    """
    cols = [col for col in self._columns
            if not hiddencols or col not in hiddencols]
    header = [self._QuoteCSVValue(col) for col in cols]
    filehandle.write(','.join(header) + '\n')
    for row in self._rows:
      vals = [self._QuoteCSVValue(row.get(col, self.EMPTY_CELL))
              for col in cols]
      filehandle.write(','.join(vals) + '\n')