Improved a few things; querying for multiple services now runs at a tolerable speed. I would still like it to be faster, so I will look at pre-merging the tables using SQL rather than pandas.

This commit is contained in:
2026-05-25 21:21:53 +01:00
parent e723109a0a
commit 0479f1e4a8
+51 -30
View File
@@ -8,6 +8,7 @@ It seems unlikely that the CIF format will be modified any time soon.
""" """
# Imports # Imports
from collections.abc import Iterable
import sqlite3 import sqlite3
from datetime import datetime from datetime import datetime
from dataclasses import dataclass, field from dataclasses import dataclass, field
@@ -65,6 +66,14 @@ class SixDate(str):
def from_datetime(cls, dt: datetime): def from_datetime(cls, dt: datetime):
return cls.from_unix(str(dt).split(" ")[0]) return cls.from_unix(str(dt).split(" ")[0])
@property
def weekday(self) -> int:
    """Day-of-week index of this date.

    Delegates to the ``weekday()`` method of whatever ``self.as_datetime()``
    returns; assuming that is a ``datetime.datetime``, Monday is 0 and
    Sunday is 6.
    """
    as_dt = self.as_datetime()
    return as_dt.weekday()
@property
def weekday_like(self) -> str:
    """SQL ``LIKE`` pattern that tests the flag at this date's weekday position.

    The pattern is ``self.weekday`` single-character wildcards (``_``),
    followed by a literal ``1`` and a trailing ``%`` that accepts any
    remaining characters. For example, weekday 2 gives ``"__1%"``.
    """
    leading_wildcards = "_" * self.weekday
    return leading_wildcards + "1%"
@dataclass @dataclass
class Schedule: class Schedule:
@@ -186,33 +195,40 @@ class Timetable:
with self.engine.connect() as connection: with self.engine.connect() as connection:
return pd.read_sql(query, connection) return pd.read_sql(query, connection)
def _fetch_records_of_schedules(
    self,
    record_type: BaseRecord,
    *schedule_numbers: int,
) -> pd.DataFrame:
    """Fetch all rows of one record type for the given schedules in one query.

    Args:
        record_type: Record table to read. It must expose ``all`` (passed
            to ``select``) and a ``schedule_number`` column.
        *schedule_numbers: Schedule numbers whose rows should be returned.

    Returns:
        The result of ``self.execute`` for the query: the rows whose
        ``schedule_number`` is one of ``schedule_numbers``.
    """
    every_column = select(record_type.all)
    in_requested = record_type.schedule_number.in_(schedule_numbers)
    return self.execute(every_column.where(in_requested))
def fetch_schedules(self, *schedule_numbers: int) -> dict[int, Schedule]:
    """Build a ``Schedule`` for every requested schedule number.

    Each record table (BS, BX, LO, LI, LT, CR) is queried once for all
    requested schedules. Each result is then grouped by ``schedule_number``
    a single time, so assembling a schedule is a keyed lookup instead of a
    boolean-mask scan of every frame per schedule.

    Args:
        *schedule_numbers: Schedule numbers to fetch. Duplicates collapse
            into a single dictionary entry.

    Returns:
        Mapping of schedule number to its ``Schedule``. ``bs`` and ``bx``
        are the first matching row of their tables, ``loit`` is the LO, LI
        and LT rows concatenated in that order with a fresh index, and
        ``cr`` is the matching CR rows (possibly empty).

    Raises:
        IndexError: If a requested schedule has no BS row or no BX row.
    """

    def rows_by_schedule(record_type):
        # One query and one groupby per table; the returned lookup hands
        # back an empty frame (same columns) for schedules with no rows,
        # matching what a boolean mask would have produced.
        frame = self._fetch_records_of_schedules(
            record_type, *schedule_numbers
        )
        groups = {
            key: rows
            for key, rows in frame.groupby("schedule_number", sort=False)
        }
        no_rows = frame.iloc[0:0]
        return lambda sn: groups.get(sn, no_rows)

    bs = rows_by_schedule(BS)
    bx = rows_by_schedule(BX)
    lo = rows_by_schedule(LO)
    li = rows_by_schedule(LI)
    lt = rows_by_schedule(LT)
    cr = rows_by_schedule(CR)
    return {
        sn: Schedule(
            sn=sn,
            bs=bs(sn).iloc[0],
            bx=bx(sn).iloc[0],
            loit=pd.concat([lo(sn), li(sn), lt(sn)]).reset_index(drop=True),
            cr=cr(sn),
        )
        for sn in schedule_numbers
    }
# Functions # Functions
@@ -222,36 +238,41 @@ def services_date_and_tiploc(
tt: Timetable | None = None, tt: Timetable | None = None,
): ):
tt = tt if tt is not None else Timetable() tt = tt if tt is not None else Timetable()
on_date = tt.execute( on_date: Iterable[int] = tt.execute(
select(BS.schedule_number).where( select(BS.schedule_number).where(
(BS.date_runs_from <= date) & (BS.date_runs_to >= date) (BS.date_runs_from <= date)
& (BS.date_runs_to >= date)
& (BS.days_run.like(date.weekday_like))
) )
).schedule_number.values ).schedule_number.values
origin = tt.execute( origin: Iterable[int] = tt.execute(
select(LO.schedule_number).where(LO.location.like(f"%{tiploc}%")) select(LO.schedule_number).where(LO.location == f"{tiploc:<8}")
).schedule_number.values ).schedule_number.values
en_route = tt.execute( en_route: Iterable[int] = tt.execute(
select(LI.schedule_number).where(LI.location.like(f"%{tiploc}%")) select(LI.schedule_number).where(
).schedule_number.values (LI.location == f"{tiploc:<8}")
destination = tt.execute( & (LI.scheduled_departure_time.not_like(" %"))
select(LT.schedule_number).where(LT.location.like(f"%{tiploc}%"))
).schedule_number.values
sns = np.intersect1d(
np.array(on_date), np.unique([*origin, *en_route, *destination])
) )
).schedule_number.values
destination: Iterable[int] = tt.execute(
select(LT.schedule_number).where(LT.location == f"{tiploc:<8}")
).schedule_number.values
sns = np.unique([*origin, *en_route, *destination])
sns = np.intersect1d(sns, on_date) # pyright: ignore[reportCallIssue, reportUnknownVariableType, reportArgumentType]
sns = [int(sn) for sn in sns] # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType]
return [ return [
Service( Service(
date=date, date=date,
**tt.fetch_schedule(int(sn)).__dict__, # pyright: ignore[reportAny] **schedule.__dict__, # pyright: ignore[reportAny]
) )
for sn in sns # pyright: ignore[reportAny] for _, schedule in tt.fetch_schedules(*sns).items()
] ]
# Script # Script
def main():
    """Print the services for a fixed date and TIPLOC, and return their count."""
    services = services_date_and_tiploc(SixDate("260525"), "YORK")
    print(services)
    return len(services)
if __name__ == "__main__": if __name__ == "__main__":