
Commit

Add example_filter.py and adjust weekly start date function.
rmkujala committed Apr 10, 2017
1 parent f4f44e8 commit f33ea0b
Showing 10 changed files with 119 additions and 78 deletions.
21 changes: 20 additions & 1 deletion examples/example_filter.py
@@ -1 +1,20 @@
-# TODO!
+import datetime
+import os
+
+from example_import import load_or_import_example_gtfs
+from gtfspy.gtfs import GTFS
+from gtfspy.filter import FilterExtract
+
+G = load_or_import_example_gtfs()
+assert isinstance(G, GTFS)
+
+filtered_database_path = "test_db_kuopio.week.sqlite"
+if os.path.exists(filtered_database_path):
+    os.remove(filtered_database_path)
+
+week_start = G.get_weekly_extract_start_date()
+week_end = week_start + datetime.timedelta(days=7)
+fe = FilterExtract(G, filtered_database_path, start_date=week_start, end_date=week_end)
+
+fe.create_filtered_copy()
+assert (os.path.exists(filtered_database_path))
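A quick way to sanity-check the weekly extract produced above (a sketch, not part of this commit; it relies only on APIs visible in this diff, namely the GTFS constructor and get_trip_counts_per_day):

# Sketch: confirm the filtered copy is limited to the extracted week.
from gtfspy.gtfs import GTFS

G_week = GTFS("test_db_kuopio.week.sqlite")
counts = G_week.get_trip_counts_per_day()  # DataFrame with 'date', 'date_str' and 'trip_counts' columns
active = counts[counts['trip_counts'] > 0]
print(active['date_str'].min(), "to", active['date_str'].max())  # should cover at most 7 days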
8 changes: 5 additions & 3 deletions examples/example_import.py
@@ -5,7 +5,7 @@
from gtfspy import osm_transfers


-def load_or_import_gtfs(verbose=False):
+def load_or_import_example_gtfs(verbose=False):
    imported_database_path = "test_db_kuopio.sqlite"
    if not os.path.exists(imported_database_path):  # reimport only if the imported database does not already exist
        print("Importing gtfs zip file")
@@ -16,6 +16,8 @@ def load_or_import_gtfs(verbose=False):

        # Note: this is an optional step, which is not necessary for many things.
        print("Computing walking paths using OSM")
+        G = gtfs.GTFS(imported_database_path)
+        G.meta['download_date'] = "2017-03-15"

        osm_path = "data/kuopio_extract_mapzen_2017_03_15.osm.pbf"

@@ -33,10 +35,10 @@

    if verbose:
        print("Location name:" + G.get_location_name())  # should print Kuopio
-        print("Time span of the data in unixtime: " + str(G.get_conservative_gtfs_time_span_in_ut()))
+        print("Time span of the data in unixtime: " + str(G.get_approximate_schedule_time_span_in_ut()))
        # prints the time span in unix time
    return G


if __name__ == "__main__":
-    load_or_import_gtfs(verbose=True)
+    load_or_import_example_gtfs(verbose=True)
4 changes: 2 additions & 2 deletions examples/example_map_visualization.py
@@ -1,8 +1,8 @@
from gtfspy import mapviz
-from example_import import load_or_import_gtfs
+from example_import import load_or_import_example_gtfs
from matplotlib import pyplot as plt

-g = load_or_import_gtfs()
+g = load_or_import_example_gtfs()

ax = mapviz.plot_route_network(g)
ax = mapviz.plot_all_stops(g, ax)
41 changes: 41 additions & 0 deletions examples/example_plot_trip_counts.py
@@ -0,0 +1,41 @@
+import functools
+import os
+
+from example_import import load_or_import_example_gtfs
+from matplotlib import pyplot as plt
+from gtfspy.gtfs import GTFS
+
+G = load_or_import_example_gtfs()
+
+daily_trip_counts = G.get_trip_counts_per_day()
+f, ax = plt.subplots()
+
+datetimes = [date.to_pydatetime() for date in daily_trip_counts['date']]
+trip_counts = daily_trip_counts['trip_counts']
+
+ax.bar(datetimes, trip_counts)
+ax.axvline(G.meta['download_date'], color="red")
+threshold = 0.96
+ax.axhline(trip_counts.max() * threshold, color="red")
+ax.axvline(G.get_weekly_extract_start_date(weekdays_at_least_of_max=threshold), color="yellow")
+
+weekly_db_path = "test_db_kuopio.week.sqlite"
+if os.path.exists(weekly_db_path):
+    G = GTFS(weekly_db_path)
+    f, ax = plt.subplots()
+    daily_trip_counts = G.get_trip_counts_per_day()
+    datetimes = [date.to_pydatetime() for date in daily_trip_counts['date']]
+    trip_counts = daily_trip_counts['trip_counts']
+    ax.bar(datetimes, trip_counts)
+
+    events = list(G.generate_routable_transit_events(0, G.get_approximate_schedule_time_span_in_ut()[0]))
+    min_ut = float('inf')
+    for e in events:
+        min_ut = min(e.dep_time_ut, min_ut)
+
+    print(G.get_approximate_schedule_time_span_in_ut())
+
+plt.show()



2 changes: 1 addition & 1 deletion examples/example_temporal_distance_profile.py
@@ -5,7 +5,7 @@
from matplotlib import rc
import example_import

-G = example_import.load_or_import_gtfs()
+G = example_import.load_or_import_example_gtfs()

from_stop_name = "Ahkiotie 2 E"
to_stop_name = "Kauppahalli P"
14 changes: 8 additions & 6 deletions gtfspy/filter.py
@@ -13,10 +13,10 @@
from gtfspy import gtfs



class FilterExtract(object):

    def __init__(self,
-                 gtfs,
+                 G,
                 copy_db_path,
                 buffer_distance=None,
                 buffer_lat=None,
@@ -28,9 +28,11 @@ def __init__(self,
                 agency_distance=None):
        """
        Copy a database, and then filter it based on various criteria.
-        Only copy_and_filter method is provided as of now because we do not want to take the risk of
-        losing any data of the original databases.
+        Only method `create_filtered_copy` is provided as we do not want to take the risk of
+        losing the data stored in the original database.
+
+        G: gtfspy.gtfs.GTFS
+            the original database
        copy_db_path : str
            path to the copied database
        update_metadata : boolean, optional
@@ -86,7 +88,7 @@ def __init__(self,
        self.end_date = end_date

        self.agency_ids_to_preserve = agency_ids_to_preserve
-        self.gtfs = gtfs
+        self.gtfs = G
        self.buffer_lat = buffer_lat
        self.buffer_lon = buffer_lon
        self.buffer_distance = buffer_distance
@@ -101,7 +103,7 @@ def __init__(self,
"the directory where the copied database will reside should exist beforehand"
assert not os.path.exists(copy_db_path), "the resulting database exists already: %s" % copy_db_path

def filter_extract(self):
def create_filtered_copy(self):
# this with statement
# is used to ensure that no corrupted/uncompleted files get created in case of problems
with util.create_file(self.copy_db_path) as tempfile:
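For reference, a minimal use of the renamed API could look like the following (a sketch, not part of the commit; the parameter names match __init__ above, while the output path and coordinate values are illustrative only):

# Sketch: spatially filter a copy of a GTFS database with the renamed method.
from gtfspy.gtfs import GTFS
from gtfspy.filter import FilterExtract

G = GTFS("test_db_kuopio.sqlite")
fe = FilterExtract(G, "test_db_kuopio.buffered.sqlite",
                   buffer_lat=62.89, buffer_lon=27.68, buffer_distance=10)  # illustrative values
fe.create_filtered_copy()  # formerly filter_extract()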
84 changes: 29 additions & 55 deletions gtfspy/gtfs.py
@@ -1,6 +1,3 @@
-from __future__ import print_function
-from __future__ import unicode_literals
-
import calendar
import datetime
import logging
@@ -9,6 +6,7 @@
import sys
import time
from collections import Counter, defaultdict
+from datetime import timedelta

import numpy
import pandas as pd
@@ -20,16 +18,12 @@
from gtfspy.route_types import WALK
from gtfspy.util import wgs84_distance

-# py2/3 compatibility (copied from six)
-if sys.version_info[0] == 3:
-    binary_type = bytes
-else:
-    binary_type = str

if sys.getdefaultencoding() != 'utf-8':
    reload(sys)
    sys.setdefaultencoding('utf-8')


class GTFS(object):

    def __init__(self, fname):
@@ -59,6 +53,7 @@ def __init__(self, fname):

        assert self.conn.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchone() is not None
        self.meta = GTFSMetadata(self.conn)
+
        # Bind functions
        self.conn.create_function("find_distance", 4, wgs84_distance)

@@ -624,51 +619,9 @@ def get_trip_counts_per_day(self):
        # check that all date_strings are included (move this to tests?)
        for date_string in trip_counts_per_day.index:
            assert date_string in date_strings
-        data = {"date_str": date_strings, "trip_counts": trip_counts}
+        data = {"date": dates, "date_str": date_strings, "trip_counts": trip_counts}
        return pd.DataFrame(data)

-    # Remove these pieces of code when this function has been tested:
-    #
-    # (RK) not sure if this works or not:
-    # def localized_datetime_to_ut_seconds(self, loc_dt):
-    #     utcoffset = loc_dt.utcoffset()
-    #     print utcoffset
-    #     utc_naive = loc_dt.replace(tzinfo=None) - utcoffset
-    #     timestamp = (utc_naive - datetime.datetime(1970, 1, 1)).total_seconds()
-    #     return timestamp
-
-    # def
-    # query = "SELECT day_start_ut, count(*) AS number_of_trips FROM day_trips GROUP BY day_start_ut"
-    # trip_counts_per_day = pd.read_sql_query(query, self.conn, index_col="day_start_ut")
-    # min_day_start_ut = trip_counts_per_day.index.min()
-    # max_day_start_ut = trip_counts_per_day.index.max()
-    # spacing = 24*3600
-    # # day_start_ut is noon - 12 hours (to cover for daylight saving time changes)
-    # min_date_noon = self.ut_seconds_to_gtfs_datetime(min_day_start_ut)+datetime.timedelta(hours=12)
-    # max_date_noon = self.ut_seconds_to_gtfs_datetime(max_day_start_ut)+datetime.timedelta(hours=12)
-    # num_days = (max_date_noon-min_date_noon).days
-    # print min_date_noon, max_date_noon
-    # dates_noon = [min_date_noon + datetime.timedelta(days=x) for x in range(0, num_days+1)]
-    # day_noon_uts = [int(self.localized_datetime_to_ut_seconds(date)) for date in dates_noon]
-    # day_start_uts = [dnu-12*3600 for dnu in day_noon_uts]
-    # print day_start_uts
-    # print list(trip_counts_per_day.index.values)
-
-    # assert max_day_start_ut == day_start_uts[-1]
-    # assert min_day_start_ut == day_start_uts[0]
-
-    # trip_counts = []
-    # for dsut in day_start_uts:
-    #     try:
-    #         value = trip_counts_per_day.loc[dsut, 'number_of_trips']
-    #     except KeyError as e:
-    #         # set value to 0 if dsut is not present, i.e. when no trips
-    #         # take place on that day
-    #         value = 0
-    #     trip_counts.append(value)
-    # for dsut in trip_counts_per_day.index:
-    #     assert dsut in day_start_uts
-    # return {"day_start_uts": day_start_uts, "trip_counts":trip_counts}

    def get_suitable_date_for_daily_extract(self, date=None, ut=False):
        """
@@ -684,7 +637,6 @@ def get_suitable_date_for_daily_extract(self, date=None, ut=False):
        Iterates through the available dates forward and backward from the download date, accepting the first day that has
        at least 90 percent of the number of trips of the maximum date. The condition can be changed to something else.
        If the download date is out of range, the process will look through the dates from first to last.
        """
        daily_trips = self.get_trip_counts_per_day()
        max_daily_trips = daily_trips[u'trip_counts'].max(axis=0)
@@ -700,6 +652,28 @@ def get_suitable_date_for_daily_extract(self, date=None, ut=False):
            else:
                return row.date_str

+    def get_weekly_extract_start_date(self, ut=False, weekdays_at_least_of_max=0.9):
+        daily_trips = self.get_trip_counts_per_day()
+        download_date_str = self.meta['download_date']
+        if download_date_str == "":
+            raise RuntimeError("Download date is not specified. Cannot find a suitable start date for week extract")
+        download_date = datetime.datetime.strptime(download_date_str, "%Y-%m-%d")
+        max_trip_count = daily_trips['trip_counts'].max()
+        threshold = weekdays_at_least_of_max * max_trip_count
+        threshold_fulfilling_days = daily_trips['trip_counts'] > threshold
+
+        next_monday = download_date + timedelta(days=(7 - download_date.weekday()))
+        monday_index = daily_trips[daily_trips['date'] == next_monday].index[0]
+        while len(daily_trips.index) >= monday_index + 7:
+            if all(threshold_fulfilling_days[monday_index:monday_index + 5]):
+                row = daily_trips.iloc[monday_index]
+                if ut:
+                    return self.get_day_start_ut(row.date_str)
+                else:
+                    return row['date']
+            monday_index += 7
+        raise RuntimeError("No suitable date could be found!")

    def get_spreading_trips(self, start_time_ut, lat, lon,
                            max_duration_ut=4 * 3600,
                            min_transfer_time=30,
@@ -806,7 +780,7 @@ def get_route_name_and_type(self, route_I):
        cur = self.conn.cursor()
        results = cur.execute("SELECT name, type FROM routes WHERE route_I=(?)", (route_I,))
        name, rtype = results.fetchone()
-        return unicode(name), int(rtype)
+        return name, int(rtype)

    def get_trip_stop_coordinates(self, trip_I):
        """
@@ -1356,7 +1330,7 @@ def update_stats(self, stats):
        self.meta.update(stats)
        self.meta['stats_calc_at_ut'] = time.time()

-    def get_conservative_gtfs_time_span_in_ut(self):
+    def get_approximate_schedule_time_span_in_ut(self):
        """
        Return conservative estimates of start_time_ut and end_time_uts.
        All trips, events etc. should start after start_time_ut_conservative and end before end_time_ut_conservative
@@ -1443,7 +1417,7 @@ def __getitem__(self, key):

    def __setitem__(self, key, value):
        """Set metadata in the DB"""
-        if isinstance(value, binary_type):
+        if isinstance(value, bytes):
            value = value.decode('utf-8')
        self._conn.execute('INSERT OR REPLACE INTO metadata '
                           '(key, value) VALUES (?, ?)',
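Together these changes enable the weekly-extract workflow: get_trip_counts_per_day now exposes a 'date' column, and get_weekly_extract_start_date walks Monday by Monday from the first Monday after the download date until all five weekdays of a week exceed weekdays_at_least_of_max times the busiest day's trip count. A sketch of the call (not part of the commit; it assumes 'download_date' has been set, as example_import.py now does):

# Sketch: find the start of a representative week.
from gtfspy.gtfs import GTFS

G = GTFS("test_db_kuopio.sqlite")
G.meta['download_date'] = "2017-03-15"  # an empty download_date raises RuntimeError
monday = G.get_weekly_extract_start_date(weekdays_at_least_of_max=0.9)  # datetime-like
monday_ut = G.get_weekly_extract_start_date(ut=True)  # the same day as unix time
print(monday, monday_ut)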
12 changes: 6 additions & 6 deletions gtfspy/test/test_filter.py
@@ -35,7 +35,7 @@ def tearDown(self):

def test_copy(self):
# do a simple copy
-        FilterExtract(self.G, self.fname_copy, update_metadata=False).filter_extract()
+        FilterExtract(self.G, self.fname_copy, update_metadata=False).create_filtered_copy()
FilterExtract(self.G, self.fname_copy, update_metadata=False).create_filtered_copy()

# check that the copying has been properly performed:
hash_copy = hashlib.md5(open(self.fname_copy, 'rb').read()).hexdigest()
@@ -44,7 +44,7 @@ def test_filter_change_metadata(self):

def test_filter_change_metadata(self):
# A simple test that changing update_metadata to True, does update some stuff:
-        FilterExtract(self.G, self.fname_copy, update_metadata=True).filter_extract()
+        FilterExtract(self.G, self.fname_copy, update_metadata=True).create_filtered_copy()
FilterExtract(self.G, self.fname_copy, update_metadata=True).create_filtered_copy()
# check that the copying has been properly performed:
hash_orig = hashlib.md5(open(self.fname, 'rb').read()).hexdigest()
hash_copy = hashlib.md5(open(self.fname_copy, 'rb').read()).hexdigest()
@@ -53,7 +53,7 @@ def test_filter_by_agency(self):
os.remove(self.fname_copy)

def test_filter_by_agency(self):
-        FilterExtract(self.G, self.fname_copy, agency_ids_to_preserve=['DTA']).filter_extract()
+        FilterExtract(self.G, self.fname_copy, agency_ids_to_preserve=['DTA']).create_filtered_copy()
FilterExtract(self.G, self.fname_copy, agency_ids_to_preserve=['DTA']).create_filtered_copy()
hash_copy = hashlib.md5(open(self.fname_copy, 'rb').read()).hexdigest()
self.assertNotEqual(self.hash_orig, hash_copy)
G_copy = GTFS(self.fname_copy)
@@ -80,7 +80,7 @@ def test_filter_by_start_and_end(self):
# (Shapes are not provided in the test data currently)

# test filtering by start and end time, copy full range
-        FilterExtract(self.G, self.fname_copy, start_date=u"2007-01-01", end_date=u"2011-01-01", update_metadata=False).filter_extract()
+        FilterExtract(self.G, self.fname_copy, start_date=u"2007-01-01", end_date=u"2011-01-01", update_metadata=False).create_filtered_copy()
FilterExtract(self.G, self.fname_copy, start_date=u"2007-01-01", end_date=u"2011-01-01", update_metadata=False).create_filtered_copy()
hash_copy = hashlib.md5(open(self.fname_copy, 'rb').read()).hexdigest()
# self.assertEqual(self.hash_orig, hash_copy)

@@ -91,7 +91,7 @@ def test_filter_by_start_and_end(self):
os.remove(self.fname_copy)

# the end date is not included:
-        FilterExtract(self.G, self.fname_copy, start_date="2007-01-02", end_date="2010-12-31").filter_extract()
+        FilterExtract(self.G, self.fname_copy, start_date="2007-01-02", end_date="2010-12-31").create_filtered_copy()
FilterExtract(self.G, self.fname_copy, start_date="2007-01-02", end_date="2010-12-31").create_filtered_copy()
hash_copy = hashlib.md5(open(self.fname_copy, 'rb').read()).hexdigest()
self.assertNotEqual(self.hash_orig, hash_copy)
G_copy = GTFS(self.fname_copy)
Expand All @@ -111,7 +111,7 @@ def test_filter_by_start_and_end(self):

def test_filter_spatially(self):
# test that the db is split by a given spatial boundary
-        FilterExtract(self.G, self.fname_copy, buffer_lat=36.914893, buffer_lon=-116.76821, buffer_distance=50).filter_extract()
+        FilterExtract(self.G, self.fname_copy, buffer_lat=36.914893, buffer_lon=-116.76821, buffer_distance=50).create_filtered_copy()
FilterExtract(self.G, self.fname_copy, buffer_lat=36.914893, buffer_lon=-116.76821, buffer_distance=50).create_filtered_copy()
G_copy = GTFS(self.fname_copy)

stops_table = G_copy.get_table("stops")
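The updated tests can be run on their own; the invocation below is an assumption (standard unittest loading), not something this commit specifies:

# Sketch: run the filter tests against the renamed API (assumed invocation).
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName("gtfspy.test.test_filter")
unittest.TextTestRunner(verbosity=2).run(suite)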
6 changes: 3 additions & 3 deletions gtfspy/test/test_gtfs.py
@@ -133,7 +133,7 @@ def test_timezone_conversions(self):

def test_get_trip_trajectory_data_within_timespan(self):
# untested, really
-        s, e = self.gtfs.get_conservative_gtfs_time_span_in_ut()
+        s, e = self.gtfs.get_approximate_schedule_time_span_in_ut()
s, e = self.gtfs.get_approximate_schedule_time_span_in_ut()
res = self.gtfs.get_trip_trajectories_within_timespan(s, s + 3600 * 24)
self.assertTrue(isinstance(res, dict))
# TODO! Not properly tested yet.
@@ -233,7 +233,7 @@ def test_get_route_name_and_type_of_tripI(self):
self.assertTrue(isinstance(type_, int))

def test_get_trip_stop_time_data(self):
-        start_ut, end_ut = self.gtfs.get_conservative_gtfs_time_span_in_ut()
+        start_ut, end_ut = self.gtfs.get_approximate_schedule_time_span_in_ut()
start_ut, end_ut = self.gtfs.get_approximate_schedule_time_span_in_ut()
dsut_dict = self.gtfs.get_tripIs_within_range_by_dsut(start_ut, end_ut)
dsut, trip_Is = list(dsut_dict.items())[0]
df = self.gtfs.get_trip_stop_time_data(trip_Is[0], dsut)
@@ -259,7 +259,7 @@ def test_get_straight_line_transfer_distances(self):
self.assertGreater(len(data), 0)

def test_get_conservative_gtfs_time_span_in_ut(self):
-        start_ut, end_ut = self.gtfs.get_conservative_gtfs_time_span_in_ut()
+        start_ut, end_ut = self.gtfs.get_approximate_schedule_time_span_in_ut()
start_ut, end_ut = self.gtfs.get_approximate_schedule_time_span_in_ut()
start_dt = datetime.datetime(2007, 1, 1)
start_ut_comp = self.gtfs.unlocalized_datetime_to_ut_seconds(start_dt)
end_dt = datetime.datetime(2010, 12, 31)
