class NotNullBytesWrapper(BufferedReader):
    """Buffered binary reader that removes NUL (0x00) bytes from what it reads.

    The CSV parser refuses lines containing NUL bytes, so they are filtered
    out before the data reaches it. A single definition serves Python 2 and
    Python 3: the explicit ``super(NotNullBytesWrapper, self)`` spelling is
    valid on both.
    """

    def _read_without_nulls(self, reader, *args, **kwargs):
        """Call `reader` until it returns non-NUL data or signals the end.

        A chunk made only of NUL bytes becomes empty once filtered; returning
        that empty value would look like end of file to the caller even
        though more data follows, so such chunks are skipped.
        """
        while True:
            chunk = reader(*args, **kwargs)
            if not chunk:
                # b'' (end of file, or a zero-sized read) or None (nothing
                # available on a non-blocking stream): pass it through.
                return chunk
            cleaned = chunk.replace(b'\x00', b'')
            if cleaned:
                return cleaned

    def read(self, *args, **kwargs):
        """Same as `BufferedReader.read`, with NUL bytes removed."""
        return self._read_without_nulls(
            super(NotNullBytesWrapper, self).read, *args, **kwargs)

    def read1(self, *args, **kwargs):
        """Same as `BufferedReader.read1`, with NUL bytes removed."""
        return self._read_without_nulls(
            super(NotNullBytesWrapper, self).read1, *args, **kwargs)

    def readline(self, *args, **kwargs):
        """Same as `BufferedReader.readline`, with NUL bytes removed."""
        return self._read_without_nulls(
            super(NotNullBytesWrapper, self).readline, *args, **kwargs)
def get_filename_and_fobj(filename_or_fobj, mode='r', dont_open=False,
                          **kwargs):
    """Return a `(filename, fobj)` pair for a filename or a file-like object.

    - A filename is opened with `io.open(filename, mode=mode, **kwargs)`
      unless `dont_open` is true, in which case `fobj` is `None`.
    - A file-like object backed by a file descriptor is re-opened on that
      descriptor with the requested `mode` and `kwargs`.
    - Any other file-like object (for example `io.BytesIO`) is wrapped in an
      `io.BufferedReader`.

    For file-like objects `filename` is their `name` attribute, or `None`
    when they have none.
    """
    # TODO: what if fobj is passed, using a different mode from `mode`?

    if getattr(filename_or_fobj, 'read', None) is None:  # filename
        filename = filename_or_fobj
        fobj = None
        if not dont_open:
            fobj = io.open(filename, mode=mode, **kwargs)
        return filename, fobj

    # file-like object
    source = filename_or_fobj
    filename = getattr(source, 'name', None)
    try:
        file_number = source.fileno()
    except (AttributeError, io.UnsupportedOperation):
        # No usable descriptor: either `fileno` is missing altogether or the
        # object is in-memory, like `io.BytesIO`.
        fobj = io.BufferedReader(source, **kwargs)  # TODO: pass mode
    else:
        # Regular file. The descriptor still belongs to `source`, so the new
        # object must not close it (`closefd=False`); otherwise closing or
        # garbage-collecting either object would break the other one.
        # NOTE(review): the new object starts at the descriptor's OS-level
        # offset, which can be ahead of `source`'s position if `source` has
        # already buffered data -- confirm callers always pass fresh files.
        fobj = io.open(file_number, mode=mode, closefd=False, **kwargs)
        # Keep the owner alive for as long as the new object is in use.
        fobj._source_fobj = source

    return filename, fobj
class FilenameFObjTestCase(unittest.TestCase):
    """Check `get_filename_and_fobj` with filenames and file-like objects."""

    # TODO: test other features of this function (example: BytesIO should
    # return filename = None)

    def setUp(self):
        self.filename = 'tests/data/csv_with_null_bytes.csv'
        self.encoding = 'latin1'
        # Read the fixture inside a context manager so the handle is closed
        # deterministically instead of being left to the garbage collector.
        with io.open(self.filename, mode='rb') as fobj:
            self.data = fobj.read()
        self.decoded_data = self.data.decode(self.encoding)

    def assert_fobj(self, source, mode, expected, check_mode=True, **kwargs):
        """Pass `source` through `get_filename_and_fobj` and check the result.

        The returned object must be readable, report `mode` (unless
        `check_mode` is false) and yield `expected` when read to the end.
        """
        _, f = get_filename_and_fobj(source, mode=mode, **kwargs)
        self.assertTrue(hasattr(f, 'readable') and f.readable())
        if check_mode:
            self.assertEqual(f.mode, mode)
        self.assertEqual(f.read(), expected)

    def test_get_filename_and_fobj_passing_filename(self):
        self.assert_fobj(self.filename, 'rb', self.data)

    def test_get_filename_and_fobj_passing_text_fobj(self):
        if six.PY3:
            # The builtin `open` only accepts `encoding` on Python 3.
            fobj = open(self.filename, encoding=self.encoding)
            self.assert_fobj(fobj, 'r', self.decoded_data,
                             encoding=self.encoding)

        fobj = io.open(self.filename, encoding=self.encoding)
        self.assert_fobj(fobj, 'r', self.decoded_data,
                         encoding=self.encoding)

    def test_get_filename_and_fobj_passing_bytes_fobj(self):
        mode = 'rb'
        for opener in (open, io.open):
            fobj = opener(self.filename, mode=mode)
            self.assert_fobj(fobj, mode, self.data)

    def test_get_filename_and_fobj_passing_BytesIO(self):
        # `io.BytesIO` has no `mode` attribute and `io.BufferedReader`
        # delegates `mode` to the object it wraps, so the mode cannot be
        # asserted for in-memory streams.
        fobj = io.BytesIO(self.data)
        self.assert_fobj(fobj, 'rb', self.data, check_mode=False)