diff --git a/rows/__init__.py b/rows/__init__.py index 8f9f9374..e5563502 100644 --- a/rows/__init__.py +++ b/rows/__init__.py @@ -22,7 +22,7 @@ import rows.plugins as plugins from rows.operations import join, transform, transpose # NOQA -from rows.table import Table, FlexibleTable # NOQA +from rows.table import FlexibleTable, LazyTable, Table # NOQA from rows.localization import locale_context # NOQA diff --git a/rows/cli.py b/rows/cli.py index aada4425..8e1f631a 100755 --- a/rows/cli.py +++ b/rows/cli.py @@ -379,27 +379,30 @@ def query(input_encoding, output_encoding, input_locale, output_locale, if input_locale is not None: with rows.locale_context(input_locale): table = import_from_source(source, DEFAULT_INPUT_ENCODING, - samples=samples) + lazy=True, samples=samples) else: table = import_from_source(source, DEFAULT_INPUT_ENCODING, - samples=samples) + lazy=True, samples=samples) sqlite_connection = sqlite3.Connection(':memory:') rows.export_to_sqlite(table, sqlite_connection, table_name='table1') - result = rows.import_from_sqlite(sqlite_connection, query=query) + result = rows.import_from_sqlite(sqlite_connection, query=query, + lazy=True, samples=samples) else: # TODO: if all sources are SQLite we can also optimize the import if input_locale is not None: with rows.locale_context(input_locale): tables = [_import_table(source, encoding=input_encoding, - verify_ssl=verify_ssl, samples=samples) + verify_ssl=verify_ssl, lazy=True, + samples=samples) for source in sources] else: tables = [_import_table(source, encoding=input_encoding, - verify_ssl=verify_ssl, samples=samples) + verify_ssl=verify_ssl, lazy=True, + samples=samples) for source in sources] sqlite_connection = sqlite3.Connection(':memory:') @@ -408,7 +411,8 @@ def query(input_encoding, output_encoding, input_locale, output_locale, sqlite_connection, table_name='table{}'.format(index)) - result = rows.import_from_sqlite(sqlite_connection, query=query) + result = rows.import_from_sqlite(sqlite_connection, query=query, + lazy=True, samples=samples) # TODO: may use sys.stdout.encoding if output_file = '-' output_encoding = output_encoding or sys.stdout.encoding or \ diff --git a/rows/plugins/dicts.py b/rows/plugins/dicts.py index 713f07e4..12d1543d 100644 --- a/rows/plugins/dicts.py +++ b/rows/plugins/dicts.py @@ -52,6 +52,9 @@ def import_from_dicts(data, samples=None, *args, **kwargs): return create_table(chain([headers], data_rows), meta=meta, *args, **kwargs) +import_from_dicts.is_lazy = False + + def export_to_dicts(table, *args, **kwargs): """Export a `rows.Table` to a list of dicts""" field_names = table.field_names diff --git a/rows/plugins/ods.py b/rows/plugins/ods.py index af8e00cc..1bcf4d02 100644 --- a/rows/plugins/ods.py +++ b/rows/plugins/ods.py @@ -103,5 +103,10 @@ def import_from_ods(filename_or_fobj, index=0, *args, **kwargs): max_length = max(len(row) for row in table_rows) full_rows = complete_with_None(table_rows, max_length) + meta = {'imported_from': 'ods', 'filename': filename,} + return create_table(full_rows, meta=meta, *args, **kwargs) + + +import_from_ods.is_lazy = False diff --git a/rows/plugins/plugin_csv.py b/rows/plugins/plugin_csv.py index 99c19cd4..e162e9b0 100644 --- a/rows/plugins/plugin_csv.py +++ b/rows/plugins/plugin_csv.py @@ -118,6 +118,9 @@ def import_from_csv(filename_or_fobj, encoding='utf-8', dialect=None, return create_table(reader, meta=meta, *args, **kwargs) +import_from_csv.is_lazy = True + + def export_to_csv(table, filename_or_fobj=None, encoding='utf-8', dialect=unicodecsv.excel, batch_size=100, callback=None, *args, **kwargs): @@ -130,7 +133,6 @@ def export_to_csv(table, filename_or_fobj=None, encoding='utf-8', contents. """ # TODO: will work only if table.fields is OrderedDict - # TODO: should use fobj? What about creating a method like json.dumps? if filename_or_fobj is not None: _, fobj = get_filename_and_fobj(filename_or_fobj, mode='wb') diff --git a/rows/plugins/plugin_html.py b/rows/plugins/plugin_html.py index 9077c6f5..e4872ed2 100644 --- a/rows/plugins/plugin_html.py +++ b/rows/plugins/plugin_html.py @@ -97,6 +97,9 @@ def import_from_html(filename_or_fobj, encoding='utf-8', index=0, return create_table(table_rows, meta=meta, *args, **kwargs) +import_from_html.is_lazy = False + + def export_to_html(table, filename_or_fobj=None, encoding='utf-8', *args, **kwargs): """Export and return rows.Table data to HTML file.""" @@ -106,6 +109,7 @@ def export_to_html(table, filename_or_fobj=None, encoding='utf-8', *args, header = [' {} \n'.format(field) for field in fields] result.extend(header) result.extend([' \n', ' \n', '\n', ' \n', '\n']) + # TODO: could be lazy so we don't need to store the whole table into memory for index, row in enumerate(serialized_table, start=1): css_class = 'odd' if index % 2 == 1 else 'even' result.append(' \n'.format(css_class)) diff --git a/rows/plugins/plugin_json.py b/rows/plugins/plugin_json.py index 22628e03..3b767a72 100644 --- a/rows/plugins/plugin_json.py +++ b/rows/plugins/plugin_json.py @@ -35,6 +35,7 @@ def import_from_json(filename_or_fobj, encoding='utf-8', *args, **kwargs): filename, fobj = get_filename_and_fobj(filename_or_fobj) json_obj = json.load(fobj, encoding=encoding) + # TODO: may use import_from_dicts here field_names = list(json_obj[0].keys()) table_rows = [[item[key] for key in field_names] for item in json_obj] @@ -44,6 +45,9 @@ def import_from_json(filename_or_fobj, encoding='utf-8', *args, **kwargs): return create_table([field_names] + table_rows, meta=meta, *args, **kwargs) +import_from_json.is_lazy = False + + def _convert(value, field_type, *args, **kwargs): if value is None or field_type in ( fields.BinaryField, @@ -74,6 +78,8 @@ def export_to_json(table, filename_or_fobj=None, encoding='utf-8', indent=None, fields = table.fields prepared_table = prepare_to_export(table, *args, **kwargs) field_names = next(prepared_table) + + # TODO: could be lazy so we don't need to store the whole table into memory data = [{field_name: _convert(value, fields[field_name], *args, **kwargs) for field_name, value in zip(field_names, row)} for row in prepared_table] diff --git a/rows/plugins/plugin_parquet.py b/rows/plugins/plugin_parquet.py index 13cb0155..eda7edc8 100644 --- a/rows/plugins/plugin_parquet.py +++ b/rows/plugins/plugin_parquet.py @@ -52,8 +52,12 @@ def import_from_parquet(filename_or_fobj, *args, **kwargs): for schema in parquet._read_footer(fobj).schema if schema.type is not None]) header = list(types.keys()) - table_rows = list(parquet.reader(fobj)) # TODO: be lazy + # TODO: make it lazy + table_rows = list(parquet.reader(fobj)) meta = {'imported_from': 'parquet', 'filename': filename,} return create_table([header] + table_rows, meta=meta, force_types=types, *args, **kwargs) + + +import_from_parquet.is_lazy = False diff --git a/rows/plugins/sqlite.py b/rows/plugins/sqlite.py index 2d179e83..e3e87464 100644 --- a/rows/plugins/sqlite.py +++ b/rows/plugins/sqlite.py @@ -21,6 +21,8 @@ import sqlite3 import string +from itertools import chain + import six import rows.fields as fields @@ -29,6 +31,7 @@ prepare_to_export) SQL_TABLE_NAMES = 'SELECT name FROM sqlite_master WHERE type="table"' +# TODO: may use query args instead of string formatting SQL_CREATE_TABLE = 'CREATE TABLE IF NOT EXISTS "{table_name}" ({field_types})' SQL_SELECT_ALL = 'SELECT * FROM "{table_name}"' SQL_INSERT = 'INSERT INTO "{table_name}" ({field_names}) VALUES ({placeholders})' @@ -122,13 +125,15 @@ def import_from_sqlite(filename_or_connection, table_name='table1', query=None, if query_args is None: query_args = tuple() - table_rows = list(cursor.execute(query, query_args)) # TODO: may be lazy - header = [six.text_type(info[0]) for info in cursor.description] - cursor.close() - # TODO: should close connection also? + cursor.execute(query, query_args) + data = chain([[six.text_type(info[0]) for info in cursor.description]], + cursor) meta = {'imported_from': 'sqlite', 'filename': filename_or_connection, } - return create_table([header] + table_rows, meta=meta, *args, **kwargs) + return create_table(data, meta=meta, *args, **kwargs) + + +import_from_sqlite.is_lazy = True def export_to_sqlite(table, filename_or_connection, table_name=None, diff --git a/rows/plugins/txt.py b/rows/plugins/txt.py index cd4d0cd8..648f85ff 100644 --- a/rows/plugins/txt.py +++ b/rows/plugins/txt.py @@ -175,6 +175,9 @@ def import_from_txt(filename_or_fobj, encoding='utf-8', return create_table(table_rows, meta=meta, *args, **kwargs) +import_from_txt.is_lazy = False + + def export_to_txt(table, filename_or_fobj=None, encoding=None, frame_style="ASCII", safe_none_frame=True, *args, **kwargs): """Export a `rows.Table` to text. diff --git a/rows/plugins/utils.py b/rows/plugins/utils.py index a2d4f1f1..4ad6ed99 100644 --- a/rows/plugins/utils.py +++ b/rows/plugins/utils.py @@ -28,7 +28,7 @@ from collections.abc import Iterator from rows.fields import detect_types -from rows.table import FlexibleTable, Table +from rows.table import FlexibleTable, Table, LazyTable SLUG_CHARS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_' @@ -134,10 +134,25 @@ def make_header(field_names, permit_not=False): return result +def get_row_data(full_field_names, field_names): + + field_indexes = [full_field_names.index(field_name) + for field_name in field_names] + + def func(rows_data): + for row_data in rows_data: + yield [row_data[field_index] for field_index in field_indexes] + + return func + + def create_table(data, meta=None, fields=None, skip_header=True, import_fields=None, samples=None, force_types=None, - *args, **kwargs): + lazy=False, *args, **kwargs): + # TODO: change samples to be a fixed number + # TODO: may change samples logic (`float('inf')` or `all`) # TODO: add auto_detect_types=True parameter + table_rows = iter(data) sample_rows = [] @@ -159,6 +174,9 @@ def create_table(data, meta=None, fields=None, skip_header=True, if not isinstance(fields, OrderedDict): raise ValueError('`fields` must be an `OrderedDict`') + # TODO: if `fields` is set, we're going to have the wrong order, + # compared to the first row (header). + if skip_header: next(table_rows) @@ -181,26 +199,38 @@ def create_table(data, meta=None, fields=None, skip_header=True, new_fields[field_name] = fields[field_name] fields = new_fields - table = Table(fields=fields, meta=meta) - # TODO: put this inside Table.__init__ - for row in chain(sample_rows, table_rows): - table.append({field_name: value - for field_name, value in zip(header, row)}) + if not lazy: + table = Table(fields=fields, meta=meta) + + # TODO: put this inside Table.__init__ + for row in chain(sample_rows, table_rows): + table.append({field_name: value + for field_name, value in zip(header, row)}) + + else: + data = chain(sample_rows, table_rows) + field_names = fields.keys() + + if header != field_names: + rows_data = get_row_data(header, field_names) + data = chain(rows_data(sample_rows), rows_data(table_rows)) + + table = LazyTable(fields=fields, data=data, meta=meta) return table def prepare_to_export(table, export_fields=None, *args, **kwargs): # TODO: optimize for more used cases (export_fields=None) + + # TODO: may create `BaseTable` and use `isinstance` instead table_type = type(table) - if table_type not in (FlexibleTable, Table): + if table_type not in (FlexibleTable, Table, LazyTable): raise ValueError('Table type not recognized') - if export_fields is None: - # we use already slugged-fieldnames + if export_fields is None: # Table has slugged fieldnames already export_fields = table.field_names - else: - # we need to slug all the field names + else: # Need to slug all the field names before exporting export_fields = make_header(export_fields) table_field_names = table.field_names @@ -211,6 +241,7 @@ def prepare_to_export(table, export_fields=None, *args, **kwargs): yield export_fields + # TODO: create a standard API on all `Table` classes if table_type is Table: field_indexes = list(map(table_field_names.index, export_fields)) for row in table._rows: @@ -218,6 +249,9 @@ def prepare_to_export(table, export_fields=None, *args, **kwargs): elif table_type is FlexibleTable: for row in table._rows: yield [row[field_name] for field_name in export_fields] + elif table_type is LazyTable: + for row in table: + yield [getattr(row, field_name) for field_name in export_fields] def serialize(table, *args, **kwargs): diff --git a/rows/plugins/xls.py b/rows/plugins/xls.py index acf7b06d..eb089dd5 100644 --- a/rows/plugins/xls.py +++ b/rows/plugins/xls.py @@ -163,6 +163,9 @@ def import_from_xls(filename_or_fobj, sheet_name=None, sheet_index=0, return create_table(table_rows, meta=meta, *args, **kwargs) +import_from_xls.is_lazy = False + + def export_to_xls(table, filename_or_fobj=None, sheet_name='Sheet1', *args, **kwargs): """Export the rows.Table to XLS file and return the saved file.""" diff --git a/rows/plugins/xlsx.py b/rows/plugins/xlsx.py index 5ca95495..7efad9d6 100644 --- a/rows/plugins/xlsx.py +++ b/rows/plugins/xlsx.py @@ -79,12 +79,16 @@ def import_from_xlsx(filename_or_fobj, sheet_name=None, sheet_index=0, for row_index in range(start_row + 1, end_row + 2)] filename, _ = get_filename_and_fobj(filename_or_fobj, dont_open=True) + metadata = {'imported_from': 'xlsx', 'filename': filename, 'sheet_name': sheet_name, } return create_table(table_rows, meta=metadata, *args, **kwargs) +import_from_xlsx.is_lazy = False + + FORMATTING_STYLES = { fields.DateField: 'YYYY-MM-DD', fields.DatetimeField: 'YYYY-MM-DD HH:MM:SS', diff --git a/rows/plugins/xpath.py b/rows/plugins/xpath.py index 1a89ae55..825cc9e2 100644 --- a/rows/plugins/xpath.py +++ b/rows/plugins/xpath.py @@ -68,6 +68,7 @@ def import_from_xpath(filename_or_fobj, rows_xpath, fields_xpath, filename, fobj = get_filename_and_fobj(filename_or_fobj, mode='rb') xml = fobj.read().decode(encoding) + # TODO: make it lazy (is it possible with lxml?) tree = tree_from_string(xml) row_elements = tree.xpath(rows_xpath) @@ -79,3 +80,6 @@ def import_from_xpath(filename_or_fobj, rows_xpath, fields_xpath, 'filename': filename, 'encoding': encoding,} return create_table([header] + result_rows, meta=meta, *args, **kwargs) + + +import_from_xpath.is_lazy = False diff --git a/rows/table.py b/rows/table.py index 156961a1..418b0ca2 100644 --- a/rows/table.py +++ b/rows/table.py @@ -28,6 +28,39 @@ from collections.abc import MutableSequence, Sized +class LazyTable(object): + + def __init__(self, fields, data, meta=None): + self.fields = OrderedDict(fields) + + self.Row = namedtuple('Row', self.field_names) + self.meta = dict(meta) if meta is not None else {} + self._rows = data + + @property + def field_names(self): + return list(self.fields.keys()) + + @property + def field_types(self): + return list(self.fields.values()) + + def __repr__(self): + imported = '' + if 'imported_from' in self.meta: + imported = ' (from {})'.format(self.meta['imported_from']) + + return ''.format( + imported, len(self.fields)) + + def __iter__(self): + fields = list(self.fields.items()) + for row in self._rows: + yield self.Row(*[field_type.deserialize(value) + for value, (field_name, field_type) in + zip(row, fields)]) + + class Table(MutableSequence): def __init__(self, fields, meta=None): diff --git a/tests/tests_plugin_csv.py b/tests/tests_plugin_csv.py index 4c3cdc49..fe0dedeb 100644 --- a/tests/tests_plugin_csv.py +++ b/tests/tests_plugin_csv.py @@ -56,6 +56,7 @@ def test_imports(self): rows.plugins.plugin_csv.import_from_csv) self.assertIs(rows.export_to_csv, rows.plugins.plugin_csv.export_to_csv) + self.assertTrue(rows.import_from_csv.is_lazy) @mock.patch('rows.plugins.plugin_csv.create_table') def test_import_from_csv_uses_create_table(self, mocked_create_table): @@ -325,3 +326,30 @@ def test_export_callback(self): [x[0][0] for x in myfunc.call_args_list], [3, 6, 9, 10] ) + + def test_import_from_csv_is_lazy(self): + temp = tempfile.NamedTemporaryFile(delete=False) + filename = '{}.{}'.format(temp.name, self.file_extension) + self.files_to_delete.append(filename) + encoding = 'utf-8' + number_of_rows = 1000 + + fobj = open(filename, mode='wb+') + fobj.write('field1,field2\r\n'.encode(encoding)) + for index in range(number_of_rows): + row_data = ','.join([str(index), str(index ** 2)]) + '\r\n' + fobj.write(row_data.encode(encoding)) + fobj.flush() + total_bytes = fobj.tell() + + fobj.seek(0) + table = rows.import_from_csv(fobj, + encoding=encoding, + dialect=csv.excel, + samples=1, # pre-read only the first row + lazy=True) + self.assertEqual(fobj.tell(), 20) # 20 = len(1st line) + len(2nd line) + + data = list(table) + self.assertEqual(len(data), number_of_rows) + self.assertEqual(fobj.tell(), total_bytes) diff --git a/tests/tests_plugin_dicts.py b/tests/tests_plugin_dicts.py index 47d97fca..0309bd11 100644 --- a/tests/tests_plugin_dicts.py +++ b/tests/tests_plugin_dicts.py @@ -45,6 +45,7 @@ class PluginDictTestCase(utils.RowsTestMixIn, unittest.TestCase): def test_imports(self): self.assertIs(rows.import_from_dicts, rows.plugins.dicts.import_from_dicts) self.assertIs(rows.export_to_dicts, rows.plugins.dicts.export_to_dicts) + self.assertTrue(rows.import_from_dicts.is_lazy) @mock.patch("rows.plugins.dicts.create_table") def test_import_from_dicts_uses_create_table(self, mocked_create_table): diff --git a/tests/tests_plugin_html.py b/tests/tests_plugin_html.py index 81ede706..694a81f7 100644 --- a/tests/tests_plugin_html.py +++ b/tests/tests_plugin_html.py @@ -52,6 +52,7 @@ def test_imports(self): self.assertIs(rows.import_from_html, rows.plugins.plugin_html.import_from_html) self.assertIs(rows.export_to_html, rows.plugins.plugin_html.export_to_html) + self.assertFalse(rows.import_from_html.is_lazy) def test_import_from_html_filename(self): table = rows.import_from_html(self.filename, encoding=self.encoding) @@ -87,7 +88,7 @@ def test_import_from_html_uses_create_table(self, mocked_create_table): call = mocked_create_table.call_args kwargs['meta'] = {'imported_from': 'html', 'filename': self.filename, - 'encoding': 'iso-8859-1',} + 'encoding': 'iso-8859-1', } self.assertEqual(call[1], kwargs) def test_export_to_html_filename(self): diff --git a/tests/tests_plugin_json.py b/tests/tests_plugin_json.py index 437a2c3b..b01ab42a 100644 --- a/tests/tests_plugin_json.py +++ b/tests/tests_plugin_json.py @@ -42,6 +42,7 @@ def test_imports(self): rows.plugins.plugin_json.import_from_json) self.assertIs(rows.export_to_json, rows.plugins.plugin_json.export_to_json) + self.assertFalse(rows.import_from_json.is_lazy) @mock.patch('rows.plugins.plugin_json.create_table') def test_import_from_json_uses_create_table(self, mocked_create_table): diff --git a/tests/tests_plugin_ods.py b/tests/tests_plugin_ods.py index 66842c7a..cd60528f 100644 --- a/tests/tests_plugin_ods.py +++ b/tests/tests_plugin_ods.py @@ -36,6 +36,7 @@ class PluginOdsTestCase(utils.RowsTestMixIn, unittest.TestCase): def test_imports(self): self.assertIs(rows.import_from_ods, rows.plugins.ods.import_from_ods) + self.assertFalse(rows.import_from_ods.is_lazy) @mock.patch('rows.plugins.ods.create_table') def test_import_from_ods_uses_create_table(self, mocked_create_table): diff --git a/tests/tests_plugin_parquet.py b/tests/tests_plugin_parquet.py index c89b47cc..6daeea5e 100644 --- a/tests/tests_plugin_parquet.py +++ b/tests/tests_plugin_parquet.py @@ -61,6 +61,7 @@ class PluginParquetTestCase(unittest.TestCase): def test_imports(self): self.assertIs(rows.import_from_parquet, rows.plugins.plugin_parquet.import_from_parquet) + self.assertFalse(rows.import_from_parquet.is_lazy) @mock.patch('rows.plugins.plugin_parquet.create_table') def test_import_from_parquet_uses_create_table(self, mocked_create_table): diff --git a/tests/tests_plugin_sqlite.py b/tests/tests_plugin_sqlite.py index 8ed18f82..ad4e0dee 100644 --- a/tests/tests_plugin_sqlite.py +++ b/tests/tests_plugin_sqlite.py @@ -50,6 +50,7 @@ def test_imports(self): rows.plugins.sqlite.import_from_sqlite) self.assertIs(rows.export_to_sqlite, rows.plugins.sqlite.export_to_sqlite) + self.assertTrue(rows.import_from_sqlite.is_lazy) @mock.patch('rows.plugins.sqlite.create_table') def test_import_from_sqlite_uses_create_table(self, mocked_create_table): @@ -101,6 +102,24 @@ def test_export_to_sqlite_filename(self): table = rows.import_from_sqlite(temp.name) self.assert_table_equal(table, utils.table) + def test_import_from_sqlite_is_lazy(self): #, mocked_create_table): + connection = mock.MagicMock() + cursor = connection.cursor() + cursor.description = [('f1', None), ('f2', None), ('f3', None)] + gen = utils.LazyGenerator(max_number=1000) + igen = iter(gen) + next(igen) # get header out -- does not happen on SQLite + cursor.__iter__.return_value = igen + cursor.fetchall = lambda: list(gen) + + table = rows.import_from_sqlite(connection, lazy=True, samples=50) + self.assertIs(gen.last, 49) + + for index, _ in enumerate(table): + if index == 99: + break + self.assertIs(gen.last, 99) + def test_export_to_sqlite_connection(self): # TODO: may test file contents temp = tempfile.NamedTemporaryFile(delete=False, mode='wb') @@ -124,9 +143,9 @@ def test_export_to_sqlite_create_unique_table_name(self): rows.export_to_sqlite(second_table, temp.name) # table2 result_first_table = rows.import_from_sqlite(temp.name, - table_name='table1') + table_name='table1') result_second_table = rows.import_from_sqlite(temp.name, - table_name='table2') + table_name='table2') self.assert_table_equal(result_first_table, first_table) self.assert_table_equal(result_second_table, second_table) diff --git a/tests/tests_plugin_txt.py b/tests/tests_plugin_txt.py index 442a0e92..9e3cbd4c 100644 --- a/tests/tests_plugin_txt.py +++ b/tests/tests_plugin_txt.py @@ -41,6 +41,7 @@ class PluginTxtTestCase(utils.RowsTestMixIn, unittest.TestCase): def test_imports(self): self.assertIs(rows.import_from_txt, rows.plugins.txt.import_from_txt) self.assertIs(rows.export_to_txt, rows.plugins.txt.export_to_txt) + self.assertFalse(rows.import_from_txt.is_lazy) @mock.patch('rows.plugins.txt.create_table') def test_import_from_txt_uses_create_table(self, mocked_create_table): diff --git a/tests/tests_plugin_utils.py b/tests/tests_plugin_utils.py index b9798861..80f729e2 100644 --- a/tests/tests_plugin_utils.py +++ b/tests/tests_plugin_utils.py @@ -29,6 +29,9 @@ import rows import rows.plugins.utils as plugins_utils +from rows import fields +from rows.table import LazyTable + import tests.utils as utils from rows import fields @@ -172,6 +175,45 @@ def test_create_table_force_types(self): for field_name, field_type in force_types.items(): self.assertEqual(table.fields[field_name], field_type) + def test_create_table_returns_LazyTable(self): + header = ['field1', 'field2', 'field3'] + table_rows = [['1', '3.14', 'Álvaro'], + ['2', '2.71', 'turicas'], + ['3', '1.23', 'Justen']] + force_types = {'field2': rows.fields.DecimalField} + + table = plugins_utils.create_table([header] + table_rows, + force_types=force_types, lazy=True) + self.assertTrue(isinstance(table, LazyTable)) + + def test_create_table_LazyTable_import_fields(self): + max_number = 1000 + data = utils.LazyGenerator(max_number) + import_fields = ['number_double', 'number_sq'] + + table = plugins_utils.create_table(data, + import_fields=import_fields, + samples=None, lazy=True) + self.assertEqual(table.field_names, import_fields) + + expected_data = [{'number_double': number * 2, + 'number_sq': number ** 2} + for number in range(max_number)] + data = [dict(row._asdict()) for row in table] + self.assertEqual(data, expected_data) + + def test_create_table_sample_size(self): + max_number = 1000 + samples = 200 + + data = utils.LazyGenerator(max_number) + table = plugins_utils.create_table(data, lazy=True, samples=samples) + self.assertEqual(data.last, samples - 1) + + data = utils.LazyGenerator(max_number) + table = plugins_utils.create_table(data, samples=samples) + self.assertEqual(data.last, max_number - 1) + def test_prepare_to_export_all_fields(self): result = plugins_utils.prepare_to_export(utils.table, export_fields=None) diff --git a/tests/tests_plugin_xls.py b/tests/tests_plugin_xls.py index c40f64e0..233597b3 100644 --- a/tests/tests_plugin_xls.py +++ b/tests/tests_plugin_xls.py @@ -44,6 +44,7 @@ class PluginXlsTestCase(utils.RowsTestMixIn, unittest.TestCase): def test_imports(self): self.assertIs(rows.import_from_xls, rows.plugins.xls.import_from_xls) self.assertIs(rows.export_to_xls, rows.plugins.xls.export_to_xls) + self.assertFalse(rows.import_from_xls.is_lazy) @mock.patch('rows.plugins.xls.create_table') def test_import_from_xls_uses_create_table(self, mocked_create_table): diff --git a/tests/tests_plugin_xlsx.py b/tests/tests_plugin_xlsx.py index 701c72b9..08fe85fa 100644 --- a/tests/tests_plugin_xlsx.py +++ b/tests/tests_plugin_xlsx.py @@ -43,6 +43,7 @@ def test_imports(self): rows.plugins.xlsx.import_from_xlsx) self.assertIs(rows.export_to_xlsx, rows.plugins.xlsx.export_to_xlsx) + self.assertFalse(rows.import_from_xlsx.is_lazy) @mock.patch('rows.plugins.xlsx.create_table') def test_import_from_xlsx_uses_create_table(self, mocked_create_table): diff --git a/tests/tests_plugin_xpath.py b/tests/tests_plugin_xpath.py index e5680e5f..1b5fbf1f 100644 --- a/tests/tests_plugin_xpath.py +++ b/tests/tests_plugin_xpath.py @@ -106,9 +106,9 @@ def test_import_from_xpath_unescape_and_extract_text(self): fields_xpath = OrderedDict([('name', './/text()'), ('link', './/a/@href')]) table = rows.import_from_xpath(BytesIO(html), + encoding='utf-8', rows_xpath=rows_xpath, - fields_xpath=fields_xpath, - encoding='utf-8') + fields_xpath=fields_xpath) self.assertEqual(table[0].name, 'Abadia de Goiás (GO)') self.assertEqual(table[1].name, 'Abadiânia (GO)') diff --git a/tests/tests_table.py b/tests/tests_table.py index a901ab31..3183601e 100644 --- a/tests/tests_table.py +++ b/tests/tests_table.py @@ -26,7 +26,7 @@ import rows import rows.fields as fields -from rows.table import FlexibleTable, Table +from rows.table import FlexibleTable, LazyTable, Table binary_type_name = six.binary_type.__name__ @@ -322,6 +322,56 @@ def test_table_add_should_not_iterate_over_rows(self): self.assertFalse(table2._rows.__iter__.called) +class LazyTableTestCase(unittest.TestCase): + + def setUp(self): + fields = {'name': rows.fields.TextField, + 'birthdate': rows.fields.DateField, } + data_rows = [ + ('Álvaro Justen', datetime.date(1987, 4, 29)), + ('Somebody', datetime.date(1990, 2, 1)), + ('Douglas Adams', datetime.date(1952, 3, 11)), ] + data = (row for row in data_rows) + self.table = LazyTable(fields=fields, data=data) + + def test_LazyTable_is_present_on_main_namespace(self): + self.assertIn('LazyTable', dir(rows)) + self.assertIs(LazyTable, rows.LazyTable) + + def test_table_iteration(self): + table_rows = list(self.table) + + self.assertEqual(list(self.table), []) + self.assertEqual(len(table_rows), 3) + + self.assertEqual(table_rows[0].name, 'Álvaro Justen') + self.assertEqual(table_rows[0].birthdate, datetime.date(1987, 4, 29)) + self.assertEqual(table_rows[1].name, 'Somebody') + self.assertEqual(table_rows[1].birthdate, datetime.date(1990, 2, 1)) + self.assertEqual(table_rows[2].name, 'Douglas Adams') + self.assertEqual(table_rows[2].birthdate, datetime.date(1952, 3, 11)) + + def test_table_slicing_error(self): + with self.assertRaises(TypeError) as context_manager: + self.table[0] + + def test_field_names_and_types(self): + self.assertEqual(self.table.field_names, + list(self.table.fields.keys())) + self.assertEqual(self.table.field_types, + list(self.table.fields.values())) + + def test_table_repr(self): + expected = '' + self.assertEqual(expected, repr(self.table)) + + # TODO: may be able to setitem (column) + # TODO: may be able to delitem (column) + # TODO: may be able to getitem (column) + # TODO: may be able to append rows (but not insert) + # TODO: should it have __add__? + + class TestFlexibleTable(unittest.TestCase): def setUp(self):