Source code for pygtfs.feed

from __future__ import (division, absolute_import, print_function,
                        unicode_literals)

import os
import io
import csv

from collections import namedtuple
from zipfile import ZipFile

import six


def _row_stripper(row):
    return (cell.strip() for cell in row)


class CSV(object):
    """An iterator over the rows of one CSV table, yielding named tuples.

    The first row produced by *rows* is consumed as the header. Only the
    header fields that also appear in *columns* are kept; every later row
    is reduced to those same positions and wrapped in a namedtuple class
    named *feedtype*.

    Attributes:
        cols: tuple of the retained column indexes, or ``None`` when every
            header field is retained (no filtering is needed).
        header: tuple of the retained header names, in file order.
        Tuple: the namedtuple class used for each row.
        rows: the underlying row iterator (header already consumed).

    Raises:
        ValueError: if *columns* is empty or missing.
        StopIteration: if *rows* is already exhausted (no header row).
    """

    def __init__(self, rows, feedtype='CSVTuple', columns=None):
        # The builtin next() is available on every interpreter that accepts
        # this module's __future__ imports, so no compatibility shim is needed.
        header = list(next(rows))
        # deal with annoying unnecessary boms on utf-8
        header[0] = header[0].lstrip("\ufeff")
        if not columns:
            raise ValueError('missing columns argument')

        # we need to filter fields that exist in the csv but not our model.
        self.cols = tuple(i for i, h in enumerate(header) if h in columns)
        if len(self.cols) == len(header):
            # There is no actual filtering, we can skip it
            self.cols = None

        # Keep the retained header names so that __repr__ has something to
        # show; previously the attribute was read but never assigned.
        self.header = tuple(self._pick_columns(header))
        self.Tuple = namedtuple(feedtype, self.header)
        self.rows = rows

    def __repr__(self):
        # Wrap in a 1-tuple: formatting a bare tuple with a single %s would
        # treat its items as separate format arguments.
        return '<CSV %s>' % (self.header,)

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next row as a namedtuple.

        A completely empty row (for example a blank line in the file)
        yields ``None`` rather than a tuple, so callers should skip falsy
        results.
        """
        n = tuple(next(self.rows))
        if n:
            return self.Tuple._make(self._pick_columns(n))

    next = __next__  # python 2 compatible

    def _pick_columns(self, row):
        """Return *row* reduced to the retained column positions."""
        # ``None`` means "keep everything". An empty tuple means "keep
        # nothing" and must still filter, so test identity, not truthiness.
        if self.cols is None:
            return row
        return (row[x] for x in self.cols)
class Feed(object):
    """A collection of CSV files with headers, either zipped into an archive
    or loose in a folder.

    Attributes:
        filename: the path given to the constructor.
        feed_name: the value returned by ``derive_feed_name(filename)``.
        zf: an open ``ZipFile`` when *filename* is not a directory,
            otherwise ``None``.
        strip_fields: whether ``read_table`` strips whitespace from cells.
        reader: the bound reader method matching the running interpreter
            (``python2_reader`` or ``python3_reader``).
    """

    def __init__(self, filename, strip_fields=True):
        self.filename = filename
        self.feed_name = derive_feed_name(filename)
        self.zf = None
        self.strip_fields = strip_fields
        # Anything that is not a directory is opened as a zip archive.
        if not os.path.isdir(filename):
            self.zf = ZipFile(filename)
        if six.PY2:
            self.reader = self.python2_reader
        else:
            self.reader = self.python3_reader

    def __repr__(self):
        return '<Feed %s>' % self.filename

    def python2_reader(self, filename):
        """Yield the rows of *filename* as lists of unicode strings.

        This is a generator, so the file is opened (and a missing file is
        reported) only when iteration starts. The file handle is not closed
        explicitly.

        Raises:
            IOError: if *filename* cannot be found in the feed.
        """
        if self.zf:
            try:
                binary_file_handle = self.zf.open(filename, 'rU')
            except (KeyError, IOError):
                # ZipFile.open reports a missing member with KeyError.
                raise IOError('%s is not present in feed' % filename)
        else:
            binary_file_handle = open(os.path.join(self.filename, filename),
                                      "rb")
        reader = csv.reader(binary_file_handle)
        for row in reader:
            yield [six.text_type(x, 'utf-8') for x in row]

    def python3_reader(self, filename):
        """Return a ``csv.reader`` over *filename*, decoded as UTF-8.

        The underlying file handle is not closed explicitly; it stays open
        for as long as the returned reader is referenced.

        Raises:
            IOError: if *filename* cannot be found in the feed.
        """
        # newline="" hands line-ending handling to the csv module, so that
        # newlines embedded in quoted fields are preserved.
        if self.zf:
            try:
                text_file_handle = io.TextIOWrapper(
                    self.zf.open(filename, "r"),
                    encoding="utf-8", newline="")
            except (KeyError, IOError):
                # ZipFile.open reports a missing member with KeyError.
                raise IOError('%s is not present in feed' % filename)
        else:
            text_file_handle = open(os.path.join(self.filename, filename),
                                    "r", encoding="utf-8", newline="")
        return csv.reader(text_file_handle)

    def read_table(self, filename, columns):
        """Read *filename* from the feed and return it as a ``CSV`` object.

        Cells are stripped of surrounding whitespace when ``strip_fields``
        is set. The row type name is derived from the file's base name,
        e.g. ``stop_times.txt`` becomes ``StopTimes``.
        """
        if self.strip_fields:
            rows = (_row_stripper(row) for row in self.reader(filename))
        else:
            rows = self.reader(filename)
        base_name = filename.split('/')[-1].split('.')[0]
        feedtype = base_name.title().replace('_', '')
        return CSV(feedtype=feedtype, rows=rows, columns=columns)
def derive_feed_name(filename):
    """Return the last path component of *filename*.

    Trailing forward slashes are removed first, so a directory path such as
    ``feeds/sample/`` gives ``sample`` rather than an empty string.
    """
    without_trailing_slash = filename.rstrip('/')
    _parent, last_component = os.path.split(without_trailing_slash)
    return last_component