diff --git a/src/national_rail_timetable/mca_queries.py b/src/national_rail_timetable/mca_queries.py index 907c592..638ccaa 100644 --- a/src/national_rail_timetable/mca_queries.py +++ b/src/national_rail_timetable/mca_queries.py @@ -8,6 +8,7 @@ It seems unlikely that the CIF format will be modified any time soon. """ # Imports +from collections.abc import Iterable import sqlite3 from datetime import datetime from dataclasses import dataclass, field @@ -65,6 +66,14 @@ class SixDate(str): def from_datetime(cls, dt: datetime): return cls.from_unix(str(dt).split(" ")[0]) + @property + def weekday(self) -> int: + return self.as_datetime().weekday() + + @property + def weekday_like(self) -> str: + return "_" * self.weekday + "1%" + @dataclass class Schedule: @@ -186,33 +195,40 @@ class Timetable: with self.engine.connect() as connection: return pd.read_sql(query, connection) - def _fetch_record_of_schedule( + def _fetch_records_of_schedules( self, - schedule_number: int, record_type: BaseRecord, + *schedule_numbers: int, ) -> pd.DataFrame: return self.execute( select(record_type.all).where( - record_type.schedule_number == schedule_number + record_type.schedule_number.in_(schedule_numbers) ) ) - # TODO: This is slow, but could probably be ~7x'd through async if disk i/o allows. - # TODO: And/or a fetch_schedules which interprets multiple results in post. - def fetch_schedule(self, schedule_number: int) -> Schedule: - return Schedule( - sn=schedule_number, - bs=self._fetch_record_of_schedule(schedule_number, BS).iloc[0], - bx=self._fetch_record_of_schedule(schedule_number, BX).iloc[0], - loit=pd.concat( - [ - self._fetch_record_of_schedule(schedule_number, LO), - self._fetch_record_of_schedule(schedule_number, LI), - self._fetch_record_of_schedule(schedule_number, LT), - ] - ).reset_index(drop=True), - cr=self._fetch_record_of_schedule(schedule_number, CR), - ) + def fetch_schedules(self, *schedule_numbers: int) -> dict[int, Schedule]: + bs = self._fetch_records_of_schedules(BS, *schedule_numbers) + bx = self._fetch_records_of_schedules(BX, *schedule_numbers) + lo = self._fetch_records_of_schedules(LO, *schedule_numbers) + li = self._fetch_records_of_schedules(LI, *schedule_numbers) + lt = self._fetch_records_of_schedules(LT, *schedule_numbers) + cr = self._fetch_records_of_schedules(CR, *schedule_numbers) + return { + sn: Schedule( + sn=sn, + bs=bs[bs.schedule_number == sn].iloc[0], + bx=bx[bx.schedule_number == sn].iloc[0], + loit=pd.concat( + [ + lo[lo.schedule_number == sn], + li[li.schedule_number == sn], + lt[lt.schedule_number == sn], + ] + ).reset_index(drop=True), + cr=cr[cr.schedule_number == sn], + ) + for sn in schedule_numbers + } # Functions @@ -222,36 +238,41 @@ def services_date_and_tiploc( tt: Timetable | None = None, ): tt = tt if tt is not None else Timetable() - on_date = tt.execute( + on_date: Iterable[int] = tt.execute( select(BS.schedule_number).where( - (BS.date_runs_from <= date) & (BS.date_runs_to >= date) + (BS.date_runs_from <= date) + & (BS.date_runs_to >= date) + & (BS.days_run.like(date.weekday_like)) ) ).schedule_number.values - origin = tt.execute( - select(LO.schedule_number).where(LO.location.like(f"%{tiploc}%")) + origin: Iterable[int] = tt.execute( + select(LO.schedule_number).where(LO.location == f"{tiploc:<8}") ).schedule_number.values - en_route = tt.execute( - select(LI.schedule_number).where(LI.location.like(f"%{tiploc}%")) + en_route: Iterable[int] = tt.execute( + select(LI.schedule_number).where( + (LI.location == f"{tiploc:<8}") + & (LI.scheduled_departure_time.not_like(" %")) + ) ).schedule_number.values - destination = tt.execute( - select(LT.schedule_number).where(LT.location.like(f"%{tiploc}%")) + destination: Iterable[int] = tt.execute( + select(LT.schedule_number).where(LT.location == f"{tiploc:<8}") ).schedule_number.values - sns = np.intersect1d( - np.array(on_date), np.unique([*origin, *en_route, *destination]) - ) + sns = np.unique([*origin, *en_route, *destination]) + sns = np.intersect1d(sns, on_date) # pyright: ignore[reportCallIssue, reportUnknownVariableType, reportArgumentType] + sns = [int(sn) for sn in sns] # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType] return [ Service( date=date, - **tt.fetch_schedule(int(sn)).__dict__, # pyright: ignore[reportAny] + **schedule.__dict__, # pyright: ignore[reportAny] ) - for sn in sns # pyright: ignore[reportAny] + for _, schedule in tt.fetch_schedules(*sns).items() ] # Script def main(): - print(services_date_and_tiploc(SixDate("260524"), "CRMLNGT")) - return None + print(s := services_date_and_tiploc(SixDate("260525"), "YORK")) + return len(s) if __name__ == "__main__":