diff --git a/src/national_rail_timetable/mca_querying.py b/src/national_rail_timetable/mca_querying.py new file mode 100644 index 0000000..2fd88dc --- /dev/null +++ b/src/national_rail_timetable/mca_querying.py @@ -0,0 +1,191 @@ +""" +MCA file querying to retrieve schedule(d service)s. +""" +# pyright: reportAny=false +# pyright: reportUnknownVariableType=false +# pyright: reportUnknownArgumentType=false +# pyright: reportAttributeAccessIssue=false +# pyright: reportOperatorIssue=false + +# Imports +from dataclasses import dataclass +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Callable, Self +from zipfile import ZipFile +import numpy as np +from numpy.typing import NDArray +from national_rail_timetable.nr_requests import fetch_nr_timetable_files +from national_rail_timetable.mca_stubs import BS, BX, LI, LO, LT, CR + + +# Functions +def ux(s1: NDArray[np.byte]) -> NDArray[np.str_]: + x = s1.shape[1] + s1.dtype = f"S{x}" + return s1.astype(f"U{x}")[:, 0] + + +def sx(s1: NDArray[np.byte]) -> NDArray[np.byte]: + x = s1.shape[1] + s1.dtype = f"S{x}" + return s1[:, 0] + + +# Classes +@dataclass +class Timetable: + array: NDArray[np.byte] + + @classmethod + def from_zipfile( + cls, + zipfile: ZipFile | None = None, + ) -> Self: + zipfile = zipfile if zipfile is not None else fetch_nr_timetable_files() + with TemporaryDirectory() as tempdir: + name = zipfile.extract( + [ + zipinfo + for name, zipinfo in zipfile.NameToInfo.items() + if name.split(".")[-1] == "MCA" + ][0], + path=tempdir, + ) + array: NDArray[np.byte] = np.fromfile( + Path(tempdir) / name, dtype="S1" + ).reshape((-1, 82))[:, :-2] + return cls(array=array) + + def bs_mask(self) -> NDArray[np.bool]: + return (self.array[:, 0] == b"B") & (self.array[:, 1] == b"S") + + def lo_mask(self) -> NDArray[np.bool]: + return (self.array[:, 0] == b"L") & (self.array[:, 1] == b"O") + + @property + def sns(self) -> NDArray[np.integer]: + return np.repeat( + ( + sns := [ + 0, + *np.arange(self.array.shape[0])[self.bs_mask()], + self.array.shape[0], + ] + )[:-1], + np.diff(sns), + ) + + def fetch_schedules(self, *sns: int) -> NDArray[np.byte]: + return self.array[np.isin(self.sns, sns)] + + +@dataclass +class Query: + tt: Timetable + sns: NDArray[np.integer] | None = None + + def _query_from_mask(self, mask: NDArray[np.bool]) -> Query: + sns: NDArray[np.integer] = self.tt.sns[mask] + if self.sns is not None: + sns = np.intersect1d(sns, self.sns) + return Query(self.tt, sns) + + def on_date(self, date: str) -> Query: + mask: NDArray[np.bool] = ( + (sx(self.tt.array[:, LO().record_identity()]) == b"BS") + & (sx(self.tt.array[:, BS().date_runs_from()]) <= date.encode()) + & (sx(self.tt.array[:, BS().date_runs_to()]) >= date.encode()) + ) + return self._query_from_mask(mask) + + def origin(self, tiploc: str) -> Query: + mask: NDArray[np.bool] = ( + (sx(self.tt.array[:, LO().record_identity()]) == b"LO") + & ( + sx(self.tt.array[:, LO().location().start : LO().location().stop - 1]) + == f"{tiploc:<7}".encode() + ), + )[0] + return self._query_from_mask(mask) + + def dest(self, tiploc: str, call_number: str | int | None = None) -> Query: + mask: NDArray[np.bool] = ( + (sx(self.tt.array[:, LO().record_identity()]) == b"LT") + & ( + sx(self.tt.array[:, LT().location().start : LT().location().stop - 1]) + == f"{tiploc:<7}".encode() + ), + )[0] + if call_number is not None: + mask &= ( + sx(self.tt.array[:, LT().location().stop]) == str(call_number).encode() + ) + return self._query_from_mask(mask) + + def calls(self, tiploc: str, call_number: str | int | None = None) -> Query: + mask: NDArray[np.bool] = ( + (sx(self.tt.array[:, LI().record_identity()]) == b"LI") + & ( + sx(self.tt.array[:, LI().location().start : LI().location().stop - 1]) + == f"{tiploc:<7}".encode() + ), + )[0] + if call_number is not None: + mask &= ( + sx(self.tt.array[:, LI().location().stop]) == str(call_number).encode() + ) + return self._query_from_mask(mask) | self.origin(tiploc) | self.dest(tiploc) + + def get_field(self, mca_field: Callable[..., slice]): + record_type = str(mca_field).split("method ")[1].split(".")[0] + return sx( + self.result[sx(self.result[:, :2]) == record_type.encode(), mca_field()] + ) + + @property + def _a(self): + return self.tt.array + + @property + def result(self) -> NDArray[np.byte]: + assert self.sns is not None + return self.tt.fetch_schedules(*self.sns) + + def as_sx(self) -> NDArray[np.byte]: + return sx(self.result) + + def as_ux(self) -> NDArray[np.str_]: + return ux(self.result) + + def __and__(self, other: Query) -> Query: + assert self.tt is other.tt + assert self.sns is not None + assert other.sns is not None + return Query(self.tt, np.intersect1d(self.sns, other.sns)) + + def __or__(self, other: Query) -> Query: + assert self.tt is other.tt + assert self.sns is not None + assert other.sns is not None + return Query(self.tt, np.union1d(self.sns, other.sns)) + + +# Script +def main(): + + try: + tt = Timetable( + np.load(Path(__file__).parents[2] / "data/cache.mca.npy"), + ) + except FileNotFoundError: + tt = Timetable.from_zipfile() + np.save(Path(__file__).parents[2] / "data/cache.mca", tt.array) + + result = Query(tt).on_date("260526").calls("CRMLNGT") + print(result.as_ux()) + print(result.get_field(BX().retail_service_id)) + + +if __name__ == "__main__": + main()