@dataclass
class Service(Schedule):
    """A ``Schedule`` paired with the specific date it was looked up for."""

    # The date that was passed to the query which produced this service.
    date: SixDate


# Functions
def services_date_and_tiploc(
    date: SixDate,
    tiploc: str,
    tt: Timetable | None = None,
) -> list[Service]:
    """Return the services valid on *date* that list *tiploc* as a location.

    A schedule is selected when both of the following hold:

    * ``BS.date_runs_from <= date <= BS.date_runs_to`` (inclusive at both ends).
    * *tiploc* occurs in the ``location`` column of at least one of its
      ``LO`` (origin), ``LI`` (en route) or ``LT`` (destination) rows.

    Args:
        date: The day to look services up for.
        tiploc: Location code to search for. It is matched as a literal
            substring of the stored location; ``%`` and ``_`` are escaped and
            therefore never act as SQL wildcards.
        tt: Timetable to query. A default ``Timetable()`` is created when
            omitted.

    Returns:
        One ``Service`` per matching schedule number, in ascending
        schedule-number order. The list is empty when nothing matches.

    NOTE(review): only the from/to window of ``BS`` is checked here. If the
    schedule data also carries a days-of-week mask or cancellation/overlay
    records, they are not applied by this function -- confirm whether callers
    expect that.
    NOTE(review): the substring match is kept from the original
    implementation, presumably because stored locations carry padding or a
    suffix; it also matches longer codes that merely contain *tiploc* --
    confirm against the table contents.
    """
    tt = tt if tt is not None else Timetable()

    # Schedule numbers whose validity window contains the requested date.
    on_date = tt.execute(
        select(BS.schedule_number).where(
            (BS.date_runs_from <= date) & (BS.date_runs_to >= date)
        )
    ).schedule_number.values

    # Schedule numbers that mention the location, one array per record type.
    # ``contains(..., autoescape=True)`` renders ``LIKE '%' || :p || '%'`` with
    # any wildcard characters inside ``tiploc`` escaped.
    calling_at = [
        tt.execute(
            select(record.schedule_number).where(
                record.location.contains(tiploc, autoescape=True)
            )
        ).schedule_number.values
        for record in (LO, LI, LT)
    ]

    # intersect1d de-duplicates and sorts both inputs itself, so the three
    # per-record arrays only need to be joined, not made unique first.
    sns = np.intersect1d(np.asarray(on_date), np.concatenate(calling_at))

    # TODO: fetch_schedule is slow and is called once per match here; it could
    # probably be ~7x'd through async if disk i/o allows, and/or replaced by a
    # fetch_schedules which interprets multiple results in post.
    return [
        Service(
            date=date,
            **tt.fetch_schedule(int(sn)).__dict__,  # pyright: ignore[reportAny]
        )
        for sn in sns  # pyright: ignore[reportAny]
    ]


# Script
def main() -> None:
    """Print the services found for ``SixDate("260524")`` at ``CRMLNGT``."""
    print(services_date_and_tiploc(SixDate("260524"), "CRMLNGT"))
    return None