Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 28 additions & 2 deletions rows/plugins/plugin_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from __future__ import unicode_literals

from io import BytesIO
from io import open as io_open, BytesIO, BufferedReader

import six
import unicodecsv
Expand All @@ -27,8 +27,19 @@

sniffer = unicodecsv.Sniffer()


if six.PY2:

class NotNullBytesWrapper(BufferedReader):

def read(self, *args, **kwargs):
data = super(NotNullBytesWrapper, self).read(*args, **kwargs)
return data.replace(b'\x00', b'')

def readline(self, *args, **kwargs):
data = super(NotNullBytesWrapper, self).readline(*args, **kwargs)
return data.replace(b'\x00', b'')

def discover_dialect(sample, encoding=None,
delimiters=(b',', b';', b'\t', b'|')):
"""Discover a CSV dialect based on a sample size
Expand All @@ -49,6 +60,16 @@ def discover_dialect(sample, encoding=None,

elif six.PY3:

class NotNullBytesWrapper(BufferedReader):

def read(self, *args, **kwargs):
data = super().read(*args, **kwargs)
return data.replace(b'\x00', b'')

def readline(self, *args, **kwargs):
data = super().readline(*args, **kwargs)
return data.replace(b'\x00', b'')

def discover_dialect(sample, encoding, delimiters=(',', ';', '\t', '|')):
"""Discover a CSV dialect based on a sample size

Expand Down Expand Up @@ -105,6 +126,11 @@ def import_from_csv(filename_or_fobj, encoding='utf-8', dialect=None,

filename, fobj = get_filename_and_fobj(filename_or_fobj, mode='rb')

if six.PY2:
fobj = NotNullBytesWrapper(io_open(filename, mode='rb'))
elif six.PY3:
fobj = NotNullBytesWrapper(fobj)

if dialect is None:
dialect = discover_dialect(sample=read_sample(fobj, sample_size),
encoding=encoding)
Expand Down Expand Up @@ -136,7 +162,7 @@ def export_to_csv(table, filename_or_fobj=None, encoding='utf-8',
else:
fobj = BytesIO()

# TODO: may use `io.BufferedWriter` instead of `ipartition` so user can
# TODO: may use `BufferedWriter` instead of `ipartition` so user can
# choose the real size (in Bytes) when to flush to the file system, instead
# number of rows
writer = unicodecsv.writer(fobj, encoding=encoding, dialect=dialect)
Expand Down
22 changes: 17 additions & 5 deletions rows/plugins/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import unicode_literals

import io
from collections import Iterator, OrderedDict
from itertools import chain, islice
from unicodedata import normalize
Expand Down Expand Up @@ -82,13 +83,24 @@ def ipartition(iterable, partition_size):
yield data


def get_filename_and_fobj(filename_or_fobj, mode='r', dont_open=False):
if getattr(filename_or_fobj, 'read', None) is not None:
def get_filename_and_fobj(filename_or_fobj, mode='r', dont_open=False, **kwargs):

# TODO: what if fobj is passed, using a different mode from `mode`?

if getattr(filename_or_fobj, 'read', None) is not None: # file-like object
filename = getattr(filename_or_fobj, 'name', None)
fobj = filename_or_fobj
filename = getattr(fobj, 'name', None)
else:
fobj = open(filename_or_fobj, mode=mode) if not dont_open else None
try:
file_number = fobj.fileno()
except io.UnsupportedOperation:
# Another kind of file object, like `io.BytesIO`
fobj = io.BufferedReader(fobj, **kwargs) # TODO: pass mode
else: # Regular file
fobj = io.open(file_number, mode=mode, **kwargs)

else: # filename
filename = filename_or_fobj
fobj = io.open(filename_or_fobj, mode=mode, **kwargs) if not dont_open else None

return filename, fobj

Expand Down
Binary file added tests/data/csv_with_null_bytes.csv
Binary file not shown.
7 changes: 7 additions & 0 deletions tests/tests_plugin_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,3 +325,10 @@ def test_export_callback(self):
[x[0][0] for x in myfunc.call_args_list],
[3, 6, 9, 10]
)

def test_issue_273(self):
filename = 'tests/data/csv_with_null_bytes.csv'
# Should not raise Error: line contains NULL byte
table = rows.import_from_csv(filename, encoding='latin-1')

self.assertEqual(len(table), 9)
60 changes: 59 additions & 1 deletion tests/tests_plugin_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from __future__ import unicode_literals

import io
import itertools
import random
import tempfile
Expand All @@ -33,6 +34,8 @@
from rows import fields


get_filename_and_fobj = plugins_utils.get_filename_and_fobj

class GenericUtilsTestCase(unittest.TestCase):

def test_slug(self):
Expand Down Expand Up @@ -339,7 +342,62 @@ def test_export_data(self):
result = plugins_utils.export_data(filename_or_fobj, data)
self.assertIs(result, data)


class FilenameFObjTestCase(unittest.TestCase):
# TODO: test other features of this function (example: BytesIO should
# return filename = None)

def setUp(self):
self.filename = 'tests/data/csv_with_null_bytes.csv'
self.encoding = 'latin1'
self.data = io.open(self.filename, mode='rb').read()
self.decoded_data = self.data.decode(self.encoding)

def test_get_filename_and_fobj_passing_filename(self):
mode = 'rb'
_, f = get_filename_and_fobj(self.filename, mode=mode)
self.assertTrue(hasattr(f, 'readable') and f.readable())
self.assertEqual(f.mode, mode)
self.assertEqual(f.read(), self.data)

def test_get_filename_and_fobj_passing_text_fobj(self):
if six.PY3:
mode = 'r'
fobj = open(self.filename, encoding=self.encoding)
_, f = get_filename_and_fobj(fobj, mode=mode, encoding=self.encoding)
self.assertTrue(hasattr(f, 'readable') and f.readable())
self.assertEqual(f.mode, mode)
self.assertEqual(f.read(), self.decoded_data)

mode = 'r'
fobj = io.open(self.filename, encoding=self.encoding)
_, f = get_filename_and_fobj(fobj, mode=mode, encoding=self.encoding)
self.assertTrue(hasattr(f, 'readable') and f.readable())
self.assertEqual(f.mode, mode)
self.assertEqual(f.read(), self.decoded_data)

def test_get_filename_and_fobj_passing_bytes_fobj(self):
mode = 'rb'
fobj = open(self.filename, mode=mode)
_, f = get_filename_and_fobj(fobj, mode=mode)
self.assertTrue(hasattr(f, 'readable') and f.readable())
self.assertEqual(f.mode, mode)
self.assertEqual(f.read(), self.data)

fobj = io.open(self.filename, mode=mode)
_, f = get_filename_and_fobj(fobj, mode=mode)
self.assertTrue(hasattr(f, 'readable') and f.readable())
self.assertEqual(f.mode, mode)
self.assertEqual(f.read(), self.data)

def test_get_filename_and_fobj_passing_BytesIO(self):
mode = 'rb'
fobj = io.BytesIO(self.data)
_, f = get_filename_and_fobj(fobj, mode=mode)
self.assertTrue(hasattr(f, 'readable') and f.readable())
self.assertEqual(f.mode, mode)
self.assertEqual(f.read(), self.data)

# TODO: test make_header
# TODO: test all features of create_table
# TODO: test if error is raised if len(row) != len(fields)
# TODO: test get_fobj_and_filename (BytesIO should return filename = None)