Commit 626551d4 authored by dagal's avatar dagal
Browse files

sql network from gtfs

parent 19e75a57
......@@ -14,7 +14,7 @@ from pyqtree import Index
import utm
seconds_day = 24*60*60
days_cycle = 31
days_cycle = 30
def parse_time(time_str):
time_val = [int(t) for t in time_str.split(":")]
......@@ -105,10 +105,31 @@ class Network():
return trip
def compute_trips_by_stop(self):
def compute_trips_by_stop(self, prefix=None):
tripsByStop = [{} for _ in xrange(7)]
lineStops = {}
if prefix:
fls = open(prefix + "_linestops.sql", "w")
for trip in self.trips.values():
if prefix:
line = trip.get_line()
stops = [s for (_,_,s) in trip.stops]
if line in lineStops:
if lineStops[line] != stops:
sys.stderr.write(line + ": " + lineStops[line] + " vs " + stops + "\n")
else:
lineStops[line] = stops
if len(set(stops)) == len(stops):
for i,s in enumerate(stops):
mystr = "INSERT INTO line_stop (line_id, stop_id, seq) VALUES ('%s', '%s', %d);\n" % (line, s, i+1)
fls.write(mystr.encode("utf-8"))
else:
sys.stderr.write("Skipping line " + line + "\n")
for day in [d[0] for d in enumerate(trip.days) if d[1]]:
for _,end_time,stop in trip.stops:
if stop in tripsByStop[day]:
......@@ -116,15 +137,21 @@ class Network():
else:
tripsByStop[day][stop] = [(end_time, trip.id)]
if prefix:
fls.close()
for day in tripsByStop:
for trips in day.values():
trips.sort()
return tripsByStop
def calculate_trips_by_day(self, stop_dict, cycle):
def calculate_trips_by_day(self, stop_dict, cycle, prefix=None):
tripsByDay = [{} for _ in xrange(cycle)]
tripCounter = {}
if prefix:
fj = open(prefix + "_journeys.sql", "w")
for day in xrange(cycle):
#print "DAY " + str(day)
......@@ -145,10 +172,18 @@ class Network():
tripCounter[line] = 1
tripsByDay[day][(line, trip.start_time)] = c
if prefix:
mystr = "INSERT INTO journey (line_id, start_time) VALUES ('%s', to_timestamp(%d) at time zone 'UTC');\n" % (line, day*seconds_day + trip.start_time)
fj.write(mystr.encode("utf-8"))
#print trip.get_line() + ": "
#print ",".join(["%s-%s" % (encode_time(s[0]), stop_dict[s[2]]) for s in trip.stops])
if prefix:
fj.close()
return tripsByDay
class TDay():
......@@ -191,22 +226,40 @@ class TTime():
#return self.strftime("%H:%M")
return str(self.val())
def parse_gtfs(file_in, file_out, file_freqs = None, network = Network()):
def parse_gtfs(file_in, file_out, file_freqs = None, network = Network(), prefix = None):
loader = transitfeed.Loader(file_in, problems = transitfeed.problems.ProblemReporter())
sched = loader.Load()
if prefix:
flines = open(prefix + "_lines.sql", "w")
fstops = open(prefix + "_stops.sql", "w")
for r in sched.GetRouteList():
mystr = "INSERT INTO line (id, short_name, long_name) VALUES ('%sd0', '%s', '%s');\n" % (r.route_id, r.route_short_name, r.route_long_name)
flines.write(mystr.encode("utf-8"))
mystr = "INSERT INTO line (id, short_name, long_name) VALUES ('%sd1', '%s(R)', '%s (vuelta)');\n" % (r.route_id, r.route_short_name, r.route_long_name)
flines.write(mystr.encode("utf-8"))
for t in sched.GetTripList():
days = sched.GetServicePeriod(t.service_id).day_of_week
trip = Trip(t.trip_id, t.route_id, t.direction_id, days)
for stop_time in t.GetStopTimes():
s = sched.GetStop(stop_time.stop_id)
if prefix and s.stop_id not in network.stops:
mystr = "INSERT INTO stop (id, name, lat, lon) VALUES ('%s', '%s', %f, %f);\n" % (s.stop_id, s.stop_name, s.stop_lat, s.stop_lon)
fstops.write(mystr.encode("utf-8"))
stop = network.add_stop(Stop(s.stop_id, lat=s.stop_lat, lng=s.stop_lon))
trip.add_stop(parse_time(stop_time.arrival_time), parse_time(stop_time.departure_time), stop)
network.add_trip(trip)
if prefix:
flines.close()
fstops.close()
if file_freqs:
file = open(file_freqs)
for line in file:
......@@ -288,18 +341,46 @@ def load_subway(prefix, network):
st1.lines.add(line.id)
st2.lines.add(line.id)
def load_stops(path, nStops=20000):
stops = [set()] * nStops
with open(path) as stopLines:
for textLine in stopLines:
splitLine = textLine.strip().split(": ")
if len(splitLine) > 1:
[stopStr, lineStr] = splitLine
stops[int(stopStr)] = set(lineStr.split(","))
return stops
def main(argv):
#n_traj = 0
n_traj = 10000000
#n_traj = 1
n_traj = 1000
#n_traj = 10000000
#change_probs = [0.50, 0.90, 0.95, 0.98, 1.0]
change_probs = [0.98, 0.98, 0.99, 1.0]
#change_probs = [0.98, 0.98, 0.99, 1.0]
switch_weights = [10, 10, 5, 1]
changes = collections.Counter()
lengths = collections.Counter()
#network = parse_gtfs("madrid_emt.zip", None)
#network = parse_gtfs("madrid_bus.zip", "madrid_bus.dat", network=network)
network = load_gtfs("madrid_bus.dat")
#print("BEGIN TRANSACTION;")
#network = parse_gtfs("madrid_emt.zip", "madrid_emt.dat", prefix="emt")
#network = parse_gtfs("madrid_crtm.zip", "madrid_bus.dat", network=network)
#print("COMMIT;")
#sys.exit(0)
network = load_gtfs("madrid_emt.dat")
#tripsByStop = network.compute_trips_by_stop(prefix="emt")
tripsByStop = network.compute_trips_by_stop()
#sys.exit(0)
valid_stops = None
if len(argv) > 1:
valid_stops = load_stops(argv[1], len(network.stops)+1)
#load_subway("london", network)
#parse_gtfs("Madrid.zip", "madrid.dat")
......@@ -311,20 +392,55 @@ def main(argv):
#print len(network.lines.keys())
#network.assign_freqs()
stops = network.stops.values()
stops_dict = {key: value+1 for value, key in enumerate(sorted(network.stops.keys()))}
#tripsByDay = network.calculate_trips_by_day(stops_dict, days_cycle, prefix="emt")
tripsByDay = network.calculate_trips_by_day(stops_dict, days_cycle)
#sys.exit(0)
unused_stops = set(network.stops)
i = 0
max_waiting_time = 30*60
err = 0
skip_bad = 0
n_switch = [n_traj/x for x in switch_weights]
while i < n_traj:
for switch,n in enumerate(n_switch):
i = 0
while i < n:
current_day = getRandomDay()
origin = network.stops[random.choice(tripsByStop[current_day.val() % 7].keys())]
(t, trip_id) = random.choice(tripsByStop[current_day.val() % 7][origin.id])
current_trip = network.trips[trip_id]
current_line = current_trip.get_line()
next_stops = [s for s in current_trip.stops if s[1] > t and s[2] != origin]
if len(next_stops) == 0:
skip_bad += 1
continue
(t,_,next) = random.choice(next_stops)
trajectory = [origin.id, next]
lines = [current_line, current_line]
times = 2*[tripsByDay[current_day.val()][(current_line, current_trip.start_time)]]
changes.update([switch])
lengths.update([len(trajectory)])
unused_stops.difference_update(trajectory)
sys.stdout.write(",".join(map(lambda (l,s,t): "%s:%s:%s" % (l,str(stops_dict[s]),str(t)), zip(lines, trajectory, times))))
sys.stdout.write("\n")
i += 1
break
while i < 0:
current_day = getRandomDay()
origin = network.stops[random.choice(tripsByStop[current_day.val() % 7].keys())]
(t, trip_id) = random.choice(tripsByStop[current_day.val() % 7][origin.id])
current_trip = network.trips[trip_id]
print "Current trip: " + str(current_trip.id)
next_stops = [s for s in reversed(current_trip.stops) if s[1] > t and s[2] != origin]
print next_stops
if len(next_stops) == 0:
continue
......@@ -338,6 +454,10 @@ def main(argv):
lines = [current_line]
(t,_,next) = next_stops.pop()
complete_trajectory = [origin.id, next]
print complete_trajectory
print trajectory
print lines
print times
try:
while random.random() > prob_next:
......@@ -377,6 +497,13 @@ def main(argv):
t = prev_t
prob_next = 1
else:
if valid_stops:
if current_line not in valid_stops[stops_dict[prev]]:
print("DAMN")
print((stops_dict[prev], current_line))
skip_bad += 1
break
(t, current_trip) = random.choice(possible_trips)
next_stops = [s for s in reversed(current_trip.stops) if s[1] > t]
......@@ -388,6 +515,12 @@ def main(argv):
current_time = TTime(current_day, 0) + t
cur_changes += 1
if current_line not in valid_stops[stops_dict[next]]:
print("DAMN")
print((stops_dict[next], current_line))
skip_bad += 1
break
lines.append(current_line)
trajectory.append(next)
times.append(tripsByDay[current_day.val()][(current_line, current_trip.start_time)])
......@@ -409,7 +542,7 @@ def main(argv):
times.append(tripsByDay[current_day.val()][(current_line, current_trip.start_time)])
lines.append(current_line)
if len(trajectory) % 2 == 1:
if len(trajectory) % 2 == 1 or any([]):
# Just discard this piece of shit
err += 1
continue
......@@ -429,6 +562,8 @@ def main(argv):
if unused_stops:
sys.stderr.write("\nWARNING: " + str(len(unused_stops)) + " Unused stops\n")
sys.stderr.write("\nSKIPPED: " + str(skip_bad) + "\n")
sys.stderr.write("\nERRORS: " + str(err) + "\n")
sys.stderr.write("\nLENGTHS: " + str(lengths) + "\n")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment