diff --git a/examples/example_filter.py b/examples/example_filter.py
index 9c52526..50fec02 100644
--- a/examples/example_filter.py
+++ b/examples/example_filter.py
@@ -1 +1,22 @@
-# TODO!
+import datetime
+import os
+
+from example_import import load_or_import_example_gtfs
+from gtfspy.gtfs import GTFS
+from gtfspy.filter import FilterExtract
+
+G = load_or_import_example_gtfs()
+assert isinstance(G, GTFS)
+
+# remove any filtered database left over from an earlier run
+filtered_database_path = "test_db_kuopio.week.sqlite"
+if os.path.exists(filtered_database_path):
+    os.remove(filtered_database_path)
+
+# extract one week of data starting from a "good" Monday
+week_start = G.get_weekly_extract_start_date()
+week_end = week_start + datetime.timedelta(days=7)
+fe = FilterExtract(G, filtered_database_path, start_date=week_start, end_date=week_end)
+
+fe.create_filtered_copy()
+assert os.path.exists(filtered_database_path)
diff --git a/examples/example_import.py b/examples/example_import.py
index 68d48db..71559c3 100644
--- a/examples/example_import.py
+++ b/examples/example_import.py
@@ -5,7 +5,7 @@
 from gtfspy import osm_transfers
 
 
-def load_or_import_gtfs(verbose=False):
+def load_or_import_example_gtfs(verbose=False):
     imported_database_path = "test_db_kuopio.sqlite"
     if not os.path.exists(imported_database_path):  # reimport only if the imported database does not already exist
         print("Importing gtfs zip file")
@@ -16,6 +16,8 @@ def load_or_import_example_gtfs(verbose=False):
 
         # Note: this is an optional step, and not necessary for many use cases.
         print("Computing walking paths using OSM")
+        G = gtfs.GTFS(imported_database_path)
+        G.meta['download_date'] = "2017-03-15"
 
         osm_path = "data/kuopio_extract_mapzen_2017_03_15.osm.pbf"
@@ -33,10 +35,10 @@ def load_or_import_example_gtfs(verbose=False):
     if verbose:
         print("Location name:" + G.get_location_name())  # should print Kuopio
-        print("Time span of the data in unixtime: " + str(G.get_conservative_gtfs_time_span_in_ut()))
+        print("Time span of the data in unixtime: " + str(G.get_approximate_schedule_time_span_in_ut()))  # prints a (start_ut, end_ut) tuple
     return G
 
 
 if __name__ == "__main__":
-    load_or_import_gtfs(verbose=True)
+    load_or_import_example_gtfs(verbose=True)
diff --git a/examples/example_map_visualization.py b/examples/example_map_visualization.py
index a85ea9e..3879377 100644
--- a/examples/example_map_visualization.py
+++ b/examples/example_map_visualization.py
@@ -1,8 +1,8 @@
 from gtfspy import mapviz
-from example_import import load_or_import_gtfs
+from example_import import load_or_import_example_gtfs
 from matplotlib import pyplot as plt
 
-g = load_or_import_gtfs()
+g = load_or_import_example_gtfs()
 
 ax = mapviz.plot_route_network(g)
 ax = mapviz.plot_all_stops(g, ax)
diff --git a/examples/example_plot_trip_counts.py b/examples/example_plot_trip_counts.py
index e69de29..87945c9 100644
--- a/examples/example_plot_trip_counts.py
+++ b/examples/example_plot_trip_counts.py
@@ -0,0 +1,41 @@
+import os
+
+from example_import import load_or_import_example_gtfs
+from matplotlib import pyplot as plt
+from gtfspy.gtfs import GTFS
+
+G = load_or_import_example_gtfs()
+
+daily_trip_counts = G.get_trip_counts_per_day()
+f, ax = plt.subplots()
+
+datetimes = [date.to_pydatetime() for date in daily_trip_counts['date']]
+trip_counts = daily_trip_counts['trip_counts']
+
+ax.bar(datetimes, trip_counts)
+# mark the download date (red), the trip-count threshold (red), and the weekly extract start (yellow)
+ax.axvline(G.meta['download_date'], color="red")
+threshold = 0.96
+ax.axhline(trip_counts.max() * threshold, color="red")
+ax.axvline(G.get_weekly_extract_start_date(weekdays_at_least_of_max=threshold), color="yellow")
+
+weekly_db_path = "test_db_kuopio.week.sqlite"
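+# if the one-week extract created by examples/example_filter.py exists,
+# also plot the daily trip counts of that extract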
+if os.path.exists(weekly_db_path):
+    G = GTFS(weekly_db_path)
+    f, ax = plt.subplots()
+    daily_trip_counts = G.get_trip_counts_per_day()
+    datetimes = [date.to_pydatetime() for date in daily_trip_counts['date']]
+    trip_counts = daily_trip_counts['trip_counts']
+    ax.bar(datetimes, trip_counts)
+
+    # sanity check: the earliest routable event should fall within the approximate schedule time span
+    events = list(G.generate_routable_transit_events(0, G.get_approximate_schedule_time_span_in_ut()[1]))
+    min_ut = float('inf')
+    for e in events:
+        min_ut = min(e.dep_time_ut, min_ut)
+
+    print(min_ut, G.get_approximate_schedule_time_span_in_ut())
+
+plt.show()
diff --git a/examples/example_temporal_distance_profile.py b/examples/example_temporal_distance_profile.py
index 1fd1c24..09f5691 100644
--- a/examples/example_temporal_distance_profile.py
+++ b/examples/example_temporal_distance_profile.py
@@ -5,7 +5,7 @@ from matplotlib import rc
 
 import example_import
 
-G = example_import.load_or_import_gtfs()
+G = example_import.load_or_import_example_gtfs()
 from_stop_name = "Ahkiotie 2 E"
 to_stop_name = "Kauppahalli P"
diff --git a/gtfspy/filter.py b/gtfspy/filter.py
index bc60edd..e31120c 100644
--- a/gtfspy/filter.py
+++ b/gtfspy/filter.py
@@ -13,10 +13,10 @@ from gtfspy import gtfs
 
 
-
 class FilterExtract(object):
+
     def __init__(self,
-                 gtfs,
+                 G,
                  copy_db_path,
                  buffer_distance=None,
                  buffer_lat=None,
@@ -28,9 +30,11 @@ def __init__(self,
                  agency_distance=None):
         """
         Copy a database, and filter its contents based on the given parameters.
-        Only copy_and_filter method is provided as of now because we do not want to take the risk of
-        losing any data of the original databases.
+        Only the method `create_filtered_copy` is provided, as we do not want to take the risk of
+        losing any data stored in the original database.
 
+        G : gtfspy.gtfs.GTFS
+            the original database
         copy_db_path : str
             path to the new database copy
         update_metadata : boolean, optional
@@ -86,7 +88,7 @@ def __init__(self,
         self.end_date = end_date
         self.agency_ids_to_preserve = agency_ids_to_preserve
-        self.gtfs = gtfs
+        self.gtfs = G
         self.buffer_lat = buffer_lat
         self.buffer_lon = buffer_lon
         self.buffer_distance = buffer_distance
@@ -101,7 +103,7 @@ def __init__(self,
             "the directory where the copied database will reside should exist beforehand"
         assert not os.path.exists(copy_db_path), "the resulting database exists already: %s" % copy_db_path
 
-    def filter_extract(self):
+    def create_filtered_copy(self):
         # this with statement
         # is used to ensure that no corrupted/uncompleted files get created in case of problems
         with util.create_file(self.copy_db_path) as tempfile:
diff --git a/gtfspy/gtfs.py b/gtfspy/gtfs.py
index a189615..c1018bc 100644
--- a/gtfspy/gtfs.py
+++ b/gtfspy/gtfs.py
@@ -1,6 +1,3 @@
-from __future__ import print_function
-from __future__ import unicode_literals
-
 import calendar
 import datetime
 import logging
@@ -9,6 +6,7 @@ import sys
 import time
 from collections import Counter, defaultdict
+from datetime import timedelta
 
 import numpy
 import pandas as pd
@@ -20,16 +18,12 @@ from gtfspy.route_types import WALK
 from gtfspy.util import wgs84_distance
 
-# py2/3 compatibility (copied from six)
-if sys.version_info[0] == 3:
-    binary_type = bytes
-else:
-    binary_type = str
 
 if sys.getdefaultencoding() != 'utf-8':
     reload(sys)
     sys.setdefaultencoding('utf-8')
 
+
 class GTFS(object):
 
     def __init__(self, fname):
@@ -59,6 +53,9 @@ def __init__(self, fname):
         assert self.conn.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchone() is not None
         self.meta = GTFSMetadata(self.conn)
+
         # Bind functions
         self.conn.create_function("find_distance", 4, wgs84_distance)
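+        # (create_function exposes wgs84_distance as a 4-argument SQL function
+        # "find_distance", usable directly in queries against this connection)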
@@ -624,51 +619,9 @@ def get_trip_counts_per_day(self):
         # check that all date_strings are included (move this to tests?)
         for date_string in trip_counts_per_day.index:
             assert date_string in date_strings
-        data = {"date_str": date_strings, "trip_counts": trip_counts}
+        data = {"date": dates, "date_str": date_strings, "trip_counts": trip_counts}
         return pd.DataFrame(data)
 
-    # Remove these pieces of code when this function has been tested:
-    #
-    # (RK) not sure if this works or not:
-    # def localized_datetime_to_ut_seconds(self, loc_dt):
-    #     utcoffset = loc_dt.utcoffset()
-    #     print utcoffset
-    #     utc_naive = loc_dt.replace(tzinfo=None) - utcoffset
-    #     timestamp = (utc_naive - datetime.datetime(1970, 1, 1)).total_seconds()
-    #     return timestamp
-
-    # def
-    #     query = "SELECT day_start_ut, count(*) AS number_of_trips FROM day_trips GROUP BY day_start_ut"
-    #     trip_counts_per_day = pd.read_sql_query(query, self.conn, index_col="day_start_ut")
-    #     min_day_start_ut = trip_counts_per_day.index.min()
-    #     max_day_start_ut = trip_counts_per_day.index.max()
-    #     spacing = 24*3600
-    #     # day_start_ut is noon - 12 hours (to cover for daylight saving time changes)
-    #     min_date_noon = self.ut_seconds_to_gtfs_datetime(min_day_start_ut)+datetime.timedelta(hours=12)
-    #     max_date_noon = self.ut_seconds_to_gtfs_datetime(max_day_start_ut)+datetime.timedelta(hours=12)
-    #     num_days = (max_date_noon-min_date_noon).days
-    #     print min_date_noon, max_date_noon
-    #     dates_noon = [min_date_noon + datetime.timedelta(days=x) for x in range(0, num_days+1)]
-    #     day_noon_uts = [int(self.localized_datetime_to_ut_seconds(date)) for date in dates_noon]
-    #     day_start_uts = [dnu-12*3600 for dnu in day_noon_uts]
-    #     print day_start_uts
-    #     print list(trip_counts_per_day.index.values)
-
-    #     assert max_day_start_ut == day_start_uts[-1]
-    #     assert min_day_start_ut == day_start_uts[0]
-
-    #     trip_counts = []
-    #     for dsut in day_start_uts:
-    #         try:
-    #             value = trip_counts_per_day.loc[dsut, 'number_of_trips']
-    #         except KeyError as e:
-    #             # set value to 0 if dsut is not present, i.e. when no trips
-    #             # take place on that day
-    #             value = 0
-    #         trip_counts.append(value)
-    #     for dsut in trip_counts_per_day.index:
-    #         assert dsut in day_start_uts
-    #     return {"day_start_uts": day_start_uts, "trip_counts":trip_counts}
 
     def get_suitable_date_for_daily_extract(self, date=None, ut=False):
         """
@@ -684,7 +637,6 @@ def get_suitable_date_for_daily_extract(self, date=None, ut=False):
         Iterates through the available dates forward and backward from the download date,
         accepting the first day that has at least 90 percent of the number of trips of the maximum date.
         The condition can be changed to something else. If the download date is out of range, the process will look through the dates from first to last.
-
         """
         daily_trips = self.get_trip_counts_per_day()
         max_daily_trips = daily_trips[u'trip_counts'].max(axis=0)
@@ -700,6 +652,37 @@ def get_suitable_date_for_daily_extract(self, date=None, ut=False):
         else:
             return row.date_str
 
+    def get_weekly_extract_start_date(self, ut=False, weekdays_at_least_of_max=0.9):
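+        """
+        Find a suitable start date (a Monday) for a weekly extract.
+
+        Returns the first Monday after the download date for which each weekday
+        (Mon-Fri) has a trip count above weekdays_at_least_of_max times the
+        maximum daily trip count. If ut is True, the day start is returned as
+        unixtime instead of a datetime. Raises RuntimeError if no week qualifies.
+        """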
+        daily_trips = self.get_trip_counts_per_day()
+        download_date_str = self.meta['download_date']
+        if download_date_str == "":
+            raise RuntimeError("Download date is not specified. Cannot find a suitable start date for the weekly extract.")
+        download_date = datetime.datetime.strptime(download_date_str, "%Y-%m-%d")
+        max_trip_count = daily_trips['trip_counts'].max()
+        threshold = weekdays_at_least_of_max * max_trip_count
+        threshold_fulfilling_days = daily_trips['trip_counts'] > threshold
+
+        # find the first Monday strictly after the download date
+        next_monday = download_date + timedelta(days=(7 - download_date.weekday()))
+        monday_index = daily_trips[daily_trips['date'] == next_monday].index[0]
+        while len(daily_trips.index) >= monday_index + 7:
+            if all(threshold_fulfilling_days[monday_index:monday_index + 5]):
+                row = daily_trips.iloc[monday_index]
+                if ut:
+                    return self.get_day_start_ut(row.date_str)
+                else:
+                    return row['date']
+            monday_index += 7
+        raise RuntimeError("No suitable date could be found!")
+
     def get_spreading_trips(self, start_time_ut, lat, lon,
                             max_duration_ut=4 * 3600,
                             min_transfer_time=30,
@@ -806,7 +780,7 @@ def get_route_name_and_type(self, route_I):
         cur = self.conn.cursor()
         results = cur.execute("SELECT name, type FROM routes WHERE route_I=(?)", (route_I,))
         name, rtype = results.fetchone()
-        return unicode(name), int(rtype)
+        return name, int(rtype)
 
     def get_trip_stop_coordinates(self, trip_I):
         """
@@ -1356,7 +1330,7 @@ def update_stats(self, stats):
         self.meta.update(stats)
         self.meta['stats_calc_at_ut'] = time.time()
 
-    def get_conservative_gtfs_time_span_in_ut(self):
+    def get_approximate_schedule_time_span_in_ut(self):
         """
         Return conservative estimates of start_time_ut and end_time_ut.
         All trips, events etc. should start after start_time_ut_conservative and end before end_time_ut_conservative
@@ -1443,7 +1417,7 @@ def __getitem__(self, key):
 
     def __setitem__(self, key, value):
         """Set metadata in the DB"""
-        if isinstance(value, binary_type):
+        if isinstance(value, bytes):
             value = value.decode('utf-8')
         self._conn.execute('INSERT OR REPLACE INTO metadata '
                            '(key, value) VALUES (?, ?)',
diff --git a/gtfspy/test/test_filter.py b/gtfspy/test/test_filter.py
index 59b146b..f173e32 100644
--- a/gtfspy/test/test_filter.py
+++ b/gtfspy/test/test_filter.py
@@ -35,7 +35,7 @@ def tearDown(self):
 
     def test_copy(self):
         # do a simple copy
-        FilterExtract(self.G, self.fname_copy, update_metadata=False).filter_extract()
+        FilterExtract(self.G, self.fname_copy, update_metadata=False).create_filtered_copy()
 
         # check that the copying has been properly performed:
         hash_copy = hashlib.md5(open(self.fname_copy, 'rb').read()).hexdigest()
@@ -44,7 +44,7 @@ def test_copy(self):
 
     def test_filter_change_metadata(self):
         # A simple test that setting update_metadata to True does update the metadata:
-        FilterExtract(self.G, self.fname_copy, update_metadata=True).filter_extract()
+        FilterExtract(self.G, self.fname_copy, update_metadata=True).create_filtered_copy()
         # check that the copying has been properly performed:
         hash_orig = hashlib.md5(open(self.fname, 'rb').read()).hexdigest()
         hash_copy = hashlib.md5(open(self.fname_copy, 'rb').read()).hexdigest()
@@ -53,7 +53,7 @@ def test_filter_change_metadata(self):
         os.remove(self.fname_copy)
 
     def test_filter_by_agency(self):
-        FilterExtract(self.G, self.fname_copy, agency_ids_to_preserve=['DTA']).filter_extract()
+        FilterExtract(self.G, self.fname_copy, agency_ids_to_preserve=['DTA']).create_filtered_copy()
         hash_copy = hashlib.md5(open(self.fname_copy, 'rb').read()).hexdigest()
         self.assertNotEqual(self.hash_orig, hash_copy)
         G_copy = GTFS(self.fname_copy)
@@ -80,7 +80,7 @@ def test_filter_by_start_and_end(self):
         # (Shapes are not provided in the test data currently)
 
         # test filtering by start and end time, copy full range
-        FilterExtract(self.G, self.fname_copy, start_date=u"2007-01-01", end_date=u"2011-01-01", update_metadata=False).filter_extract()
+        FilterExtract(self.G, self.fname_copy, start_date=u"2007-01-01", end_date=u"2011-01-01", update_metadata=False).create_filtered_copy()
         hash_copy = hashlib.md5(open(self.fname_copy, 'rb').read()).hexdigest()
 
         # self.assertEqual(self.hash_orig, hash_copy)
@@ -91,7 +91,7 @@ def test_filter_by_start_and_end(self):
         os.remove(self.fname_copy)
 
         # the end date is not included:
-        FilterExtract(self.G, self.fname_copy, start_date="2007-01-02", end_date="2010-12-31").filter_extract()
+        FilterExtract(self.G, self.fname_copy, start_date="2007-01-02", end_date="2010-12-31").create_filtered_copy()
         hash_copy = hashlib.md5(open(self.fname_copy, 'rb').read()).hexdigest()
         self.assertNotEqual(self.hash_orig, hash_copy)
         G_copy = GTFS(self.fname_copy)
@@ -111,7 +111,7 @@ def test_filter_by_start_and_end(self):
 
     def test_filter_spatially(self):
         # test that the db is split by a given spatial boundary
-        FilterExtract(self.G, self.fname_copy, buffer_lat=36.914893, buffer_lon=-116.76821, buffer_distance=50).filter_extract()
+        FilterExtract(self.G, self.fname_copy, buffer_lat=36.914893, buffer_lon=-116.76821, buffer_distance=50).create_filtered_copy()
         G_copy = GTFS(self.fname_copy)
         stops_table = G_copy.get_table("stops")
diff --git a/gtfspy/test/test_gtfs.py b/gtfspy/test/test_gtfs.py
index 85b6a40..4271161 100644
--- a/gtfspy/test/test_gtfs.py
+++ b/gtfspy/test/test_gtfs.py
@@ -133,7 +133,7 @@ def test_timezone_conversions(self):
 
     def test_get_trip_trajectory_data_within_timespan(self):
         # untested, really
-        s, e = self.gtfs.get_conservative_gtfs_time_span_in_ut()
+        s, e = self.gtfs.get_approximate_schedule_time_span_in_ut()
         res = self.gtfs.get_trip_trajectories_within_timespan(s, s + 3600 * 24)
         self.assertTrue(isinstance(res, dict))
         # TODO! Not properly tested yet.
@@ -233,7 +233,7 @@ def test_get_route_name_and_type_of_tripI(self):
         self.assertTrue(isinstance(type_, int))
 
     def test_get_trip_stop_time_data(self):
-        start_ut, end_ut = self.gtfs.get_conservative_gtfs_time_span_in_ut()
+        start_ut, end_ut = self.gtfs.get_approximate_schedule_time_span_in_ut()
         dsut_dict = self.gtfs.get_tripIs_within_range_by_dsut(start_ut, end_ut)
         dsut, trip_Is = list(dsut_dict.items())[0]
         df = self.gtfs.get_trip_stop_time_data(trip_Is[0], dsut)
@@ -259,7 +259,7 @@ def test_get_straight_line_transfer_distances(self):
         self.assertGreater(len(data), 0)
 
     def test_get_conservative_gtfs_time_span_in_ut(self):
-        start_ut, end_ut = self.gtfs.get_conservative_gtfs_time_span_in_ut()
+        start_ut, end_ut = self.gtfs.get_approximate_schedule_time_span_in_ut()
         start_dt = datetime.datetime(2007, 1, 1)
         start_ut_comp = self.gtfs.unlocalized_datetime_to_ut_seconds(start_dt)
         end_dt = datetime.datetime(2010, 12, 31)
diff --git a/gtfspy/timetable_validator.py b/gtfspy/timetable_validator.py
index 3b94948..089d6f9 100644
--- a/gtfspy/timetable_validator.py
+++ b/gtfspy/timetable_validator.py
@@ -30,6 +30,7 @@
     WARNING_STOP_SEQUENCE_ERROR
 }
 
+
 class TimetableValidator(object):
 
     def __init__(self, gtfs):
@@ -62,7 +63,9 @@ def get_warnings(self):
         self._validate_stop_sequence()
         self.warnings_container.print_summary()
         return self.warnings_container
-# TODO: check for missplaced stops in the filtered feed, by checking outside a buffer + x distance. (Routes going outside are okay if they return inside the buffer)
+
+    # TODO: check for misplaced stops in the filtered feed, by checking outside a buffer + x distance. (Routes going outside are okay if they return inside the buffer)
+
     def _validate_stops_with_same_stop_time(self):
         n_stops_with_same_time = 5
         # this query returns the trips where there are N or more stops with the same stop time
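
Usage note: complementing the date-range extract in examples/example_filter.py, the renamed
FilterExtract API also supports the spatial buffer exercised by test_filter_spatially above.
A minimal sketch, assuming an already-imported Kuopio database; the center coordinates are
illustrative only, and the buffer_distance value simply mirrors the 50 used in the test:

    from gtfspy.gtfs import GTFS
    from gtfspy.filter import FilterExtract

    G = GTFS("test_db_kuopio.sqlite")
    # keep only the data within a buffer around the given center point
    # (62.89, 27.68 approximates Kuopio's center; swap in your own coordinates)
    FilterExtract(G, "test_db_kuopio.spatial.sqlite",
                  buffer_lat=62.89, buffer_lon=27.68,
                  buffer_distance=50).create_filtered_copy()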