Created
December 28, 2014 01:25
-
-
Save PirosB3/66ac4f20d5a07bdec3e7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import defaultdict
import glob
import json
import unittest

import numpy as np
class Station(object):
    """A named station located by a (latitude, longitude) coordinate pair.

    Attributes:
        name: The station name, as passed to the constructor.
        pos: A 2-element float array ``[latitude, longitude]``.
        directions: A ``defaultdict(list)`` keyed by direction name; the
            values are whatever timetable objects callers store in it.
    """

    def __init__(self, name, latitude, longitude):
        """Create a station.

        Args:
            name: Station name.
            latitude: Latitude; anything ``float`` can convert.
            longitude: Longitude; anything ``float`` can convert.
        """
        self.name = name
        # Force a float array so integer or numeric-string coordinates
        # (e.g. straight out of JSON) still support distance arithmetic.
        self.pos = np.array([latitude, longitude], dtype=float)
        self.directions = defaultdict(list)

    def get_distance_from(self, point):
        """Return the Euclidean distance between this station and *point*.

        The distance is the plain 2-D norm of the coordinate difference, so
        it is expressed in the same units as the coordinates themselves; it
        is not a great-circle distance.

        Args:
            point: A 2-element array-like ``[latitude, longitude]``.
        """
        return np.linalg.norm(self.pos - np.asarray(point, dtype=float))

    def __repr__(self):
        # Wrap the name in a tuple so a tuple-valued name cannot be
        # misread as multiple %-format arguments.
        return "<%s station instance>" % (self.name,)
class Timetable(object):
    """Stations and their per-direction timetables, loaded from JSON files.

    Attributes:
        stations: A dict mapping station name to its ``Station`` instance.
    """

    def __init__(self, stations, data):
        """Load the stations, then attach every timetable file to them.

        Args:
            stations: Path to a JSON file holding a list of objects. Each
                object must have a ``name`` key and is passed whole to
                ``Station(**obj)``.
            data: Glob pattern matching the timetable JSON files. Each file
                must be an object with a ``towards`` key (the direction
                name) and a ``data`` key mapping station name to that
                station's timetable for the direction.

        Raises:
            KeyError: If a timetable file names a station that is not
                present in the stations file.
        """
        with open(stations) as f:
            # The loop variable is deliberately not called ``data`` so it
            # cannot be confused with the glob-pattern parameter.
            self.stations = dict(
                (entry['name'], Station(**entry))
                for entry in json.load(f)
            )

        for data_file in glob.iglob(data):
            with open(data_file) as f:
                deserialized = json.load(f)
            direction_name = deserialized['towards']
            # ``items()`` exists on both Python 2 and 3; ``iteritems()``
            # is Python 2 only.
            for station_name, timetable in deserialized['data'].items():
                self.stations[station_name].directions[direction_name] = timetable

    def get_nearest(self, pos):
        """Return the station closest to *pos*.

        Closeness is whatever ``station.get_distance_from(pos)`` reports.
        When several stations are equally close, the first one encountered
        while iterating ``self.stations`` is returned.

        Args:
            pos: The query position, forwarded unchanged to each station's
                ``get_distance_from``.

        Raises:
            IndexError: If no stations are loaded.
        """
        if not self.stations:
            raise IndexError(
                "cannot find the nearest station: no stations are loaded"
            )
        # A single linear scan replaces sorting the whole list just to
        # take its first element.
        return min(
            self.stations.values(),
            key=lambda station: station.get_distance_from(pos),
        )
class TimetableTestCase(unittest.TestCase):
    """Tests for ``Timetable`` that run against the JSON files under
    ``datasets/`` (paths are relative to the working directory)."""

    def setUp(self):
        # A fresh Timetable is loaded from disk before every test.
        self.tt = Timetable(
            stations="datasets/stations.json",
            data="datasets/data/*.json"
        )

    def test_it_can_read_stations(self):
        """The stations file yields 13 distinct station names."""
        self.assertEqual(13, len(self.tt.stations))

    def test_it_can_find_nearest(self):
        """The station nearest to (41.78, 12.35) is ACILIA."""
        self.assertEqual("ACILIA", self.tt.get_nearest(np.array([41.78, 12.35])).name)

    def test_it_can_get_directions(self):
        """ACILIA has exactly the 'ostia' and 'roma' directions."""
        self.assertEqual({'ostia', 'roma'}, set(self.tt.stations["ACILIA"].directions))
# Run the test suite only when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment