blob: a598fbd552aa55357114304a22e1ac20ef518d50 [file] [log] [blame] [edit]
# Copyright 2011 The ChromiumOS Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Unit tests for the table module."""
import tempfile
from chromite.lib import cros_test_lib
from chromite.lib import table
from chromite.utils import outcap
# pylint: disable=protected-access
class TableTest(cros_test_lib.TempDirTestCase):
    """Unit tests for the Table class."""

    COL0 = "Column1"
    COL1 = "Column2"
    COL2 = "Column3"
    COL3 = "Column4"
    COLUMNS = [COL0, COL1, COL2, COL3]

    # Rows used to populate the default table in setUp().  ROW0 deliberately
    # omits COL3.
    ROW0 = {COL0: "Xyz", COL1: "Bcd", COL2: "Cde"}
    ROW1 = {COL0: "Abc", COL1: "Bcd", COL2: "Opq", COL3: "Foo"}
    ROW2 = {COL0: "Abc", COL1: "Nop", COL2: "Wxy", COL3: "Bar"}

    # A row missing COL0, and how it is expected to read back from the table.
    EXTRAROW = {COL1: "Walk", COL2: "The", COL3: "Line"}
    EXTRAROWOUT = {COL0: "", COL1: "Walk", COL2: "The", COL3: "Line"}

    # Variants of ROW0/ROW1 that differ only in COL3, used by merge tests.
    ROW0a = {COL0: "Xyz", COL1: "Bcd", COL2: "Cde", COL3: "Yay"}
    ROW0b = {COL0: "Xyz", COL1: "Bcd", COL2: "Cde", COL3: "Boo"}
    ROW1a = {COL0: "Abc", COL1: "Bcd", COL2: "Opq", COL3: "Blu"}

    # A second column layout with an extra column inserted at index 1.
    EXTRACOL = "ExtraCol"
    EXTRACOLUMNS = [COL0, EXTRACOL, COL1, COL2]
    EROW0 = {COL0: "Xyz", EXTRACOL: "Yay", COL1: "Bcd", COL2: "Cde"}
    EROW1 = {COL0: "Abc", EXTRACOL: "Hip", COL1: "Bcd", COL2: "Opq"}
    EROW2 = {COL0: "Abc", EXTRACOL: "Yay", COL1: "Nop", COL2: "Wxy"}

    def _GetRowValsInOrder(self, row):
        """Take |row| dict and return correctly ordered values in a list.

        Values follow the order of self.COLUMNS; missing columns become "".
        """
        return [row.get(col, "") for col in self.COLUMNS]

    def _GetFullRowFor(self, row, cols):
        """Return a dict with exactly the keys |cols|, taken from |row|.

        Columns in |cols| that |row| lacks are given the value "".
        """
        return {col: row.get(col, "") for col in cols}

    def assertRowsEqual(self, row1, row2):
        """Assert two row dicts are equal, treating missing columns as ""."""
        # Compare over the union of both rows' columns so that an absent
        # column and an empty-string column are considered equivalent.
        cols = set(row1) | set(row2)
        self.assertEqual(
            self._GetFullRowFor(row1, cols), self._GetFullRowFor(row2, cols)
        )

    def assertRowListsEqual(self, rows1, rows2):
        """Assert two row sequences have equal length and equal rows.

        Rows are compared pairwise with assertRowsEqual.  The lengths are
        checked explicitly because zip() alone would silently ignore any
        trailing rows in the longer sequence.
        """
        rows1 = list(rows1)
        rows2 = list(rows2)
        self.assertEqual(len(rows1), len(rows2))
        for row1, row2 in zip(rows1, rows2):
            self.assertRowsEqual(row1, row2)

    def setUp(self):
        self._table = self._CreateTableWithRows(
            self.COLUMNS, [self.ROW0, self.ROW1, self.ROW2]
        )

    def _CreateTableWithRows(self, cols, rows):
        """Build a Table with columns |cols| and append each row in |rows|.

        Each row is copied with dict() so the table never holds references
        to the caller's dicts (e.g. the class-level row constants).  |rows|
        may be empty or None.
        """
        mytable = table.Table(list(cols))
        for row in rows or ():
            mytable.AppendRow(dict(row))
        return mytable

    def testLen(self):
        self.assertEqual(3, len(self._table))

    def testGetNumRows(self):
        self.assertEqual(3, self._table.GetNumRows())

    def testGetNumColumns(self):
        self.assertEqual(4, self._table.GetNumColumns())

    def testGetColumns(self):
        self.assertEqual(self.COLUMNS, self._table.GetColumns())

    def testGetColumnIndex(self):
        # Check every column, including the last one.
        for ix, col in enumerate(self.COLUMNS):
            self.assertEqual(ix, self._table.GetColumnIndex(col))

    def testGetColumnByIndex(self):
        # Check every column, including the last one.
        for ix, col in enumerate(self.COLUMNS):
            self.assertEqual(col, self._table.GetColumnByIndex(ix))

    def testGetByIndex(self):
        self.assertRowsEqual(self.ROW0, self._table.GetRowByIndex(0))
        self.assertRowsEqual(self.ROW0, self._table[0])
        self.assertRowsEqual(self.ROW2, self._table.GetRowByIndex(2))
        self.assertRowsEqual(self.ROW2, self._table[2])

    def testSlice(self):
        self.assertRowListsEqual([self.ROW0, self.ROW1], self._table[0:2])
        self.assertRowListsEqual([self.ROW2], self._table[-1:])

    def testGetByValue(self):
        rows = self._table.GetRowsByValue({self.COL0: "Abc"})
        self.assertEqual([self.ROW1, self.ROW2], rows)
        rows = self._table.GetRowsByValue({self.COL2: "Opq"})
        self.assertEqual([self.ROW1], rows)
        rows = self._table.GetRowsByValue({self.COL3: "Foo"})
        self.assertEqual([self.ROW1], rows)

    def testGetIndicesByValue(self):
        indices = self._table.GetRowIndicesByValue({self.COL0: "Abc"})
        self.assertEqual([1, 2], indices)
        indices = self._table.GetRowIndicesByValue({self.COL2: "Opq"})
        self.assertEqual([1], indices)
        indices = self._table.GetRowIndicesByValue({self.COL3: "Foo"})
        self.assertEqual([1], indices)

    def testAppendRowDict(self):
        self._table.AppendRow(dict(self.EXTRAROW))
        self.assertEqual(4, self._table.GetNumRows())
        self.assertEqual(self.EXTRAROWOUT, self._table[len(self._table) - 1])

    def testAppendRowList(self):
        self._table.AppendRow(self._GetRowValsInOrder(self.EXTRAROW))
        self.assertEqual(4, self._table.GetNumRows())
        self.assertEqual(self.EXTRAROWOUT, self._table[len(self._table) - 1])

    def testSetRowDictByIndex(self):
        self._table.SetRowByIndex(1, dict(self.EXTRAROW))
        self.assertEqual(3, self._table.GetNumRows())
        self.assertEqual(self.EXTRAROWOUT, self._table[1])

    def testSetRowListByIndex(self):
        self._table.SetRowByIndex(1, self._GetRowValsInOrder(self.EXTRAROW))
        self.assertEqual(3, self._table.GetNumRows())
        self.assertEqual(self.EXTRAROWOUT, self._table[1])

    def testRemoveRowByIndex(self):
        self._table.RemoveRowByIndex(1)
        self.assertEqual(2, self._table.GetNumRows())
        self.assertEqual(self.ROW2, self._table[1])

    def testRemoveRowBySlice(self):
        del self._table[0:2]
        self.assertEqual(1, self._table.GetNumRows())
        self.assertEqual(self.ROW2, self._table[0])

    def testIteration(self):
        """Iterating the table yields every row, in index order."""
        rows = list(self._table)
        self.assertEqual(len(self._table), len(rows))
        for ix, row in enumerate(rows):
            self.assertEqual(row, self._table[ix])

    def testClear(self):
        self._table.Clear()
        self.assertEqual(0, len(self._table))
        self.assertEqual(0, self._table.GetNumRows())

    def testMergeRows(self):
        # This merge should fail without a merge rule. Capture stderr to avoid
        # scary error message in test output.
        with outcap.OutputCapturer():
            self.assertRaises(
                ValueError, self._table._MergeRow, self.ROW0a, self.COL0
            )

        # Merge but stick with current row where different.
        self._table._MergeRow(
            self.ROW0a, self.COL0, merge_rules={self.COL3: "accept_this_val"}
        )
        self.assertEqual(3, len(self._table))
        self.assertRowsEqual(self.ROW0, self._table[0])

        # Merge and use new row where different.
        self._table._MergeRow(
            self.ROW0a, self.COL0, merge_rules={self.COL3: "accept_other_val"}
        )
        self.assertEqual(3, len(self._table))
        self.assertRowsEqual(self.ROW0a, self._table[0])

        # Merge and combine column values where different.  The text after
        # "join_with:" (here a single space) is the separator expected below.
        self._table._MergeRow(
            self.ROW1a, self.COL2, merge_rules={self.COL3: "join_with: "}
        )
        self.assertEqual(3, len(self._table))
        final_row = dict(self.ROW1a)
        final_row[self.COL3] = (
            self.ROW1[self.COL3] + " " + self.ROW1a[self.COL3]
        )
        self.assertRowsEqual(final_row, self._table[1])

    def testMergeTablesSameCols(self):
        other_table = self._CreateTableWithRows(
            self.COLUMNS, [self.ROW0b, self.ROW1a, self.ROW2]
        )
        self._table.MergeTable(
            other_table, self.COL2, merge_rules={self.COL3: "join_with: "}
        )
        final_row0 = self.ROW0b
        final_row1 = dict(self.ROW1a)
        final_row1[self.COL3] = (
            self.ROW1[self.COL3] + " " + self.ROW1a[self.COL3]
        )
        final_row2 = self.ROW2
        self.assertRowsEqual(final_row0, self._table[0])
        self.assertRowsEqual(final_row1, self._table[1])
        self.assertRowsEqual(final_row2, self._table[2])

    def testMergeTablesNewCols(self):
        self.assertFalse(self._table.HasColumn(self.EXTRACOL))
        other_rows = [self.EROW0, self.EROW1, self.EROW2]
        other_table = self._CreateTableWithRows(self.EXTRACOLUMNS, other_rows)
        self._table.MergeTable(
            other_table,
            self.COL2,
            allow_new_columns=True,
            merge_rules={self.COL3: "join_by_space"},
        )
        self.assertTrue(self._table.HasColumn(self.EXTRACOL))
        self.assertEqual(5, self._table.GetNumColumns())
        self.assertEqual(1, self._table.GetColumnIndex(self.EXTRACOL))
        final_row0 = dict(self.ROW0)
        final_row0[self.EXTRACOL] = self.EROW0[self.EXTRACOL]
        final_row1 = dict(self.ROW1)
        final_row1[self.EXTRACOL] = self.EROW1[self.EXTRACOL]
        final_row2 = dict(self.ROW2)
        final_row2[self.EXTRACOL] = self.EROW2[self.EXTRACOL]
        self.assertRowsEqual(final_row0, self._table[0])
        self.assertRowsEqual(final_row1, self._table[1])
        self.assertRowsEqual(final_row2, self._table[2])

    def testSort1(self):
        self.assertRowsEqual(self.ROW0, self._table[0])
        self.assertRowsEqual(self.ROW1, self._table[1])
        self.assertRowsEqual(self.ROW2, self._table[2])

        # Sort by COL3
        self._table.Sort(lambda row: row[self.COL3])
        self.assertEqual(3, len(self._table))
        self.assertRowsEqual(self.ROW0, self._table[0])
        self.assertRowsEqual(self.ROW2, self._table[1])
        self.assertRowsEqual(self.ROW1, self._table[2])

        # Reverse sort by COL3
        self._table.Sort(lambda row: row[self.COL3], reverse=True)
        self.assertEqual(3, len(self._table))
        self.assertRowsEqual(self.ROW1, self._table[0])
        self.assertRowsEqual(self.ROW2, self._table[1])
        self.assertRowsEqual(self.ROW0, self._table[2])

    def testSort2(self):
        """Test multiple key sort."""
        self.assertRowsEqual(self.ROW0, self._table[0])
        self.assertRowsEqual(self.ROW1, self._table[1])
        self.assertRowsEqual(self.ROW2, self._table[2])

        # Sort by COL0 then COL1
        def sorter(row):
            return (row[self.COL0], row[self.COL1])

        self._table.Sort(sorter)
        self.assertEqual(3, len(self._table))
        self.assertRowsEqual(self.ROW1, self._table[0])
        self.assertRowsEqual(self.ROW2, self._table[1])
        self.assertRowsEqual(self.ROW0, self._table[2])

        # Reverse the sort
        self._table.Sort(sorter, reverse=True)
        self.assertEqual(3, len(self._table))
        self.assertRowsEqual(self.ROW0, self._table[0])
        self.assertRowsEqual(self.ROW2, self._table[1])
        self.assertRowsEqual(self.ROW1, self._table[2])

    def testSplitCSVLine(self):
        """Test splitting of csv line."""
        tests = {
            "a,b,c,d": ["a", "b", "c", "d"],
            "a, b, c, d": ["a", " b", " c", " d"],
            "a,b,c,": ["a", "b", "c", ""],
            'a,"b c",d': ["a", "b c", "d"],
            'a,"b, c",d': ["a", "b, c", "d"],
            'a,"b, c, d",e': ["a", "b, c, d", "e"],
            'a,"""b, c""",d': ["a", '"b, c"', "d"],
            'a,"""b, c"", d",e': ["a", '"b, c", d', "e"],
            # Following not real Google Spreadsheet cases.
            r"a,b\,c,d": ["a", "b,c", "d"],
            'a,",c': ["a", '",c'],
            'a,"",c': ["a", "", "c"],
        }
        for line, expected in tests.items():
            # subTest reports which input line failed and keeps checking
            # the remaining cases.
            with self.subTest(line=line):
                self.assertEqual(table.Table._SplitCSVLine(line), expected)

    def testWriteReadCSV(self):
        """Write and Read CSV and verify contents preserved."""
        # This also tests the Table == and != operators.
        # Create the file under the per-test tempdir (presumably removed by
        # TempDirTestCase's cleanup) and hand the descriptor returned by
        # mkstemp() to open() so it gets closed; discarding it leaked it.
        fd, path = tempfile.mkstemp(text=True, dir=self.tempdir)
        with open(fd, "w", encoding="utf-8") as tmpfile:
            self._table.WriteCSV(tmpfile)
        mytable = table.Table.LoadFromCSV(path)
        self.assertEqual(mytable, self._table)
        self.assertFalse(mytable != self._table)

    def testInsertColumn(self):
        self._table.InsertColumn(1, self.EXTRACOL, "blah")
        goldenrow = dict(self.ROW1)
        goldenrow[self.EXTRACOL] = "blah"
        self.assertRowsEqual(goldenrow, self._table.GetRowByIndex(1))
        self.assertEqual(self.EXTRACOL, self._table.GetColumnByIndex(1))

    def testAppendColumn(self):
        self._table.AppendColumn(self.EXTRACOL, "blah")
        goldenrow = dict(self.ROW1)
        goldenrow[self.EXTRACOL] = "blah"
        self.assertRowsEqual(goldenrow, self._table.GetRowByIndex(1))
        col_size = self._table.GetNumColumns()
        self.assertEqual(
            self.EXTRACOL, self._table.GetColumnByIndex(col_size - 1)
        )

    def testProcessRows(self):
        def Processor(row):
            row[self.COL0] = row[self.COL0] + " processed"

        self._table.ProcessRows(Processor)
        final_row0 = dict(self.ROW0)
        final_row0[self.COL0] += " processed"
        final_row1 = dict(self.ROW1)
        final_row1[self.COL0] += " processed"
        final_row2 = dict(self.ROW2)
        final_row2[self.COL0] += " processed"
        self.assertRowsEqual(final_row0, self._table[0])
        self.assertRowsEqual(final_row1, self._table[1])
        self.assertRowsEqual(final_row2, self._table[2])