Tracked mca_querying.py.

This commit is contained in:
2026-05-26 17:33:25 +01:00
parent c7bb1608a1
commit cd81b7514a
+191
View File
@@ -0,0 +1,191 @@
"""
MCA file querying to retrieve schedule(d service)s.
"""
# pyright: reportAny=false
# pyright: reportUnknownVariableType=false
# pyright: reportUnknownArgumentType=false
# pyright: reportAttributeAccessIssue=false
# pyright: reportOperatorIssue=false
# Imports
from dataclasses import dataclass
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Callable, Self
from zipfile import ZipFile
import numpy as np
from numpy.typing import NDArray
from national_rail_timetable.nr_requests import fetch_nr_timetable_files
from national_rail_timetable.mca_stubs import BS, BX, LI, LO, LT, CR
# Functions
def ux(s1: NDArray[np.byte]) -> NDArray[np.str_]:
x = s1.shape[1]
s1.dtype = f"S{x}"
return s1.astype(f"U{x}")[:, 0]
def sx(s1: NDArray[np.byte]) -> NDArray[np.byte]:
x = s1.shape[1]
s1.dtype = f"S{x}"
return s1[:, 0]
# Classes
@dataclass
class Timetable:
array: NDArray[np.byte]
@classmethod
def from_zipfile(
cls,
zipfile: ZipFile | None = None,
) -> Self:
zipfile = zipfile if zipfile is not None else fetch_nr_timetable_files()
with TemporaryDirectory() as tempdir:
name = zipfile.extract(
[
zipinfo
for name, zipinfo in zipfile.NameToInfo.items()
if name.split(".")[-1] == "MCA"
][0],
path=tempdir,
)
array: NDArray[np.byte] = np.fromfile(
Path(tempdir) / name, dtype="S1"
).reshape((-1, 82))[:, :-2]
return cls(array=array)
def bs_mask(self) -> NDArray[np.bool]:
return (self.array[:, 0] == b"B") & (self.array[:, 1] == b"S")
def lo_mask(self) -> NDArray[np.bool]:
return (self.array[:, 0] == b"L") & (self.array[:, 1] == b"O")
@property
def sns(self) -> NDArray[np.integer]:
return np.repeat(
(
sns := [
0,
*np.arange(self.array.shape[0])[self.bs_mask()],
self.array.shape[0],
]
)[:-1],
np.diff(sns),
)
def fetch_schedules(self, *sns: int) -> NDArray[np.byte]:
return self.array[np.isin(self.sns, sns)]
@dataclass
class Query:
tt: Timetable
sns: NDArray[np.integer] | None = None
def _query_from_mask(self, mask: NDArray[np.bool]) -> Query:
sns: NDArray[np.integer] = self.tt.sns[mask]
if self.sns is not None:
sns = np.intersect1d(sns, self.sns)
return Query(self.tt, sns)
def on_date(self, date: str) -> Query:
mask: NDArray[np.bool] = (
(sx(self.tt.array[:, LO().record_identity()]) == b"BS")
& (sx(self.tt.array[:, BS().date_runs_from()]) <= date.encode())
& (sx(self.tt.array[:, BS().date_runs_to()]) >= date.encode())
)
return self._query_from_mask(mask)
def origin(self, tiploc: str) -> Query:
mask: NDArray[np.bool] = (
(sx(self.tt.array[:, LO().record_identity()]) == b"LO")
& (
sx(self.tt.array[:, LO().location().start : LO().location().stop - 1])
== f"{tiploc:<7}".encode()
),
)[0]
return self._query_from_mask(mask)
def dest(self, tiploc: str, call_number: str | int | None = None) -> Query:
mask: NDArray[np.bool] = (
(sx(self.tt.array[:, LO().record_identity()]) == b"LT")
& (
sx(self.tt.array[:, LT().location().start : LT().location().stop - 1])
== f"{tiploc:<7}".encode()
),
)[0]
if call_number is not None:
mask &= (
sx(self.tt.array[:, LT().location().stop]) == str(call_number).encode()
)
return self._query_from_mask(mask)
def calls(self, tiploc: str, call_number: str | int | None = None) -> Query:
mask: NDArray[np.bool] = (
(sx(self.tt.array[:, LI().record_identity()]) == b"LI")
& (
sx(self.tt.array[:, LI().location().start : LI().location().stop - 1])
== f"{tiploc:<7}".encode()
),
)[0]
if call_number is not None:
mask &= (
sx(self.tt.array[:, LI().location().stop]) == str(call_number).encode()
)
return self._query_from_mask(mask) | self.origin(tiploc) | self.dest(tiploc)
def get_field(self, mca_field: Callable[..., slice]):
record_type = str(mca_field).split("method ")[1].split(".")[0]
return sx(
self.result[sx(self.result[:, :2]) == record_type.encode(), mca_field()]
)
@property
def _a(self):
return self.tt.array
@property
def result(self) -> NDArray[np.byte]:
assert self.sns is not None
return self.tt.fetch_schedules(*self.sns)
def as_sx(self) -> NDArray[np.byte]:
return sx(self.result)
def as_ux(self) -> NDArray[np.str_]:
return ux(self.result)
def __and__(self, other: Query) -> Query:
assert self.tt is other.tt
assert self.sns is not None
assert other.sns is not None
return Query(self.tt, np.intersect1d(self.sns, other.sns))
def __or__(self, other: Query) -> Query:
assert self.tt is other.tt
assert self.sns is not None
assert other.sns is not None
return Query(self.tt, np.union1d(self.sns, other.sns))
# Script
def main():
try:
tt = Timetable(
np.load(Path(__file__).parents[2] / "data/cache.mca.npy"),
)
except FileNotFoundError:
tt = Timetable.from_zipfile()
np.save(Path(__file__).parents[2] / "data/cache.mca", tt.array)
result = Query(tt).on_date("260526").calls("CRMLNGT")
print(result.as_ux())
print(result.get_field(BX().retail_service_id))
if __name__ == "__main__":
main()