Tracked mca_querying.py.
This commit is contained in:
@@ -0,0 +1,191 @@
|
|||||||
|
"""
|
||||||
|
MCA file querying to retrieve schedule(d service)s.
|
||||||
|
"""
|
||||||
|
# pyright: reportAny=false
|
||||||
|
# pyright: reportUnknownVariableType=false
|
||||||
|
# pyright: reportUnknownArgumentType=false
|
||||||
|
# pyright: reportAttributeAccessIssue=false
|
||||||
|
# pyright: reportOperatorIssue=false
|
||||||
|
|
||||||
|
# Imports
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
|
from typing import Callable, Self
|
||||||
|
from zipfile import ZipFile
|
||||||
|
import numpy as np
|
||||||
|
from numpy.typing import NDArray
|
||||||
|
from national_rail_timetable.nr_requests import fetch_nr_timetable_files
|
||||||
|
from national_rail_timetable.mca_stubs import BS, BX, LI, LO, LT, CR
|
||||||
|
|
||||||
|
|
||||||
|
# Functions
|
||||||
|
def ux(s1: NDArray[np.byte]) -> NDArray[np.str_]:
|
||||||
|
x = s1.shape[1]
|
||||||
|
s1.dtype = f"S{x}"
|
||||||
|
return s1.astype(f"U{x}")[:, 0]
|
||||||
|
|
||||||
|
|
||||||
|
def sx(s1: NDArray[np.byte]) -> NDArray[np.byte]:
|
||||||
|
x = s1.shape[1]
|
||||||
|
s1.dtype = f"S{x}"
|
||||||
|
return s1[:, 0]
|
||||||
|
|
||||||
|
|
||||||
|
# Classes
|
||||||
|
@dataclass
|
||||||
|
class Timetable:
|
||||||
|
array: NDArray[np.byte]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_zipfile(
|
||||||
|
cls,
|
||||||
|
zipfile: ZipFile | None = None,
|
||||||
|
) -> Self:
|
||||||
|
zipfile = zipfile if zipfile is not None else fetch_nr_timetable_files()
|
||||||
|
with TemporaryDirectory() as tempdir:
|
||||||
|
name = zipfile.extract(
|
||||||
|
[
|
||||||
|
zipinfo
|
||||||
|
for name, zipinfo in zipfile.NameToInfo.items()
|
||||||
|
if name.split(".")[-1] == "MCA"
|
||||||
|
][0],
|
||||||
|
path=tempdir,
|
||||||
|
)
|
||||||
|
array: NDArray[np.byte] = np.fromfile(
|
||||||
|
Path(tempdir) / name, dtype="S1"
|
||||||
|
).reshape((-1, 82))[:, :-2]
|
||||||
|
return cls(array=array)
|
||||||
|
|
||||||
|
def bs_mask(self) -> NDArray[np.bool]:
|
||||||
|
return (self.array[:, 0] == b"B") & (self.array[:, 1] == b"S")
|
||||||
|
|
||||||
|
def lo_mask(self) -> NDArray[np.bool]:
|
||||||
|
return (self.array[:, 0] == b"L") & (self.array[:, 1] == b"O")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def sns(self) -> NDArray[np.integer]:
|
||||||
|
return np.repeat(
|
||||||
|
(
|
||||||
|
sns := [
|
||||||
|
0,
|
||||||
|
*np.arange(self.array.shape[0])[self.bs_mask()],
|
||||||
|
self.array.shape[0],
|
||||||
|
]
|
||||||
|
)[:-1],
|
||||||
|
np.diff(sns),
|
||||||
|
)
|
||||||
|
|
||||||
|
def fetch_schedules(self, *sns: int) -> NDArray[np.byte]:
|
||||||
|
return self.array[np.isin(self.sns, sns)]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Query:
|
||||||
|
tt: Timetable
|
||||||
|
sns: NDArray[np.integer] | None = None
|
||||||
|
|
||||||
|
def _query_from_mask(self, mask: NDArray[np.bool]) -> Query:
|
||||||
|
sns: NDArray[np.integer] = self.tt.sns[mask]
|
||||||
|
if self.sns is not None:
|
||||||
|
sns = np.intersect1d(sns, self.sns)
|
||||||
|
return Query(self.tt, sns)
|
||||||
|
|
||||||
|
def on_date(self, date: str) -> Query:
|
||||||
|
mask: NDArray[np.bool] = (
|
||||||
|
(sx(self.tt.array[:, LO().record_identity()]) == b"BS")
|
||||||
|
& (sx(self.tt.array[:, BS().date_runs_from()]) <= date.encode())
|
||||||
|
& (sx(self.tt.array[:, BS().date_runs_to()]) >= date.encode())
|
||||||
|
)
|
||||||
|
return self._query_from_mask(mask)
|
||||||
|
|
||||||
|
def origin(self, tiploc: str) -> Query:
|
||||||
|
mask: NDArray[np.bool] = (
|
||||||
|
(sx(self.tt.array[:, LO().record_identity()]) == b"LO")
|
||||||
|
& (
|
||||||
|
sx(self.tt.array[:, LO().location().start : LO().location().stop - 1])
|
||||||
|
== f"{tiploc:<7}".encode()
|
||||||
|
),
|
||||||
|
)[0]
|
||||||
|
return self._query_from_mask(mask)
|
||||||
|
|
||||||
|
def dest(self, tiploc: str, call_number: str | int | None = None) -> Query:
|
||||||
|
mask: NDArray[np.bool] = (
|
||||||
|
(sx(self.tt.array[:, LO().record_identity()]) == b"LT")
|
||||||
|
& (
|
||||||
|
sx(self.tt.array[:, LT().location().start : LT().location().stop - 1])
|
||||||
|
== f"{tiploc:<7}".encode()
|
||||||
|
),
|
||||||
|
)[0]
|
||||||
|
if call_number is not None:
|
||||||
|
mask &= (
|
||||||
|
sx(self.tt.array[:, LT().location().stop]) == str(call_number).encode()
|
||||||
|
)
|
||||||
|
return self._query_from_mask(mask)
|
||||||
|
|
||||||
|
def calls(self, tiploc: str, call_number: str | int | None = None) -> Query:
|
||||||
|
mask: NDArray[np.bool] = (
|
||||||
|
(sx(self.tt.array[:, LI().record_identity()]) == b"LI")
|
||||||
|
& (
|
||||||
|
sx(self.tt.array[:, LI().location().start : LI().location().stop - 1])
|
||||||
|
== f"{tiploc:<7}".encode()
|
||||||
|
),
|
||||||
|
)[0]
|
||||||
|
if call_number is not None:
|
||||||
|
mask &= (
|
||||||
|
sx(self.tt.array[:, LI().location().stop]) == str(call_number).encode()
|
||||||
|
)
|
||||||
|
return self._query_from_mask(mask) | self.origin(tiploc) | self.dest(tiploc)
|
||||||
|
|
||||||
|
def get_field(self, mca_field: Callable[..., slice]):
|
||||||
|
record_type = str(mca_field).split("method ")[1].split(".")[0]
|
||||||
|
return sx(
|
||||||
|
self.result[sx(self.result[:, :2]) == record_type.encode(), mca_field()]
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _a(self):
|
||||||
|
return self.tt.array
|
||||||
|
|
||||||
|
@property
|
||||||
|
def result(self) -> NDArray[np.byte]:
|
||||||
|
assert self.sns is not None
|
||||||
|
return self.tt.fetch_schedules(*self.sns)
|
||||||
|
|
||||||
|
def as_sx(self) -> NDArray[np.byte]:
|
||||||
|
return sx(self.result)
|
||||||
|
|
||||||
|
def as_ux(self) -> NDArray[np.str_]:
|
||||||
|
return ux(self.result)
|
||||||
|
|
||||||
|
def __and__(self, other: Query) -> Query:
|
||||||
|
assert self.tt is other.tt
|
||||||
|
assert self.sns is not None
|
||||||
|
assert other.sns is not None
|
||||||
|
return Query(self.tt, np.intersect1d(self.sns, other.sns))
|
||||||
|
|
||||||
|
def __or__(self, other: Query) -> Query:
|
||||||
|
assert self.tt is other.tt
|
||||||
|
assert self.sns is not None
|
||||||
|
assert other.sns is not None
|
||||||
|
return Query(self.tt, np.union1d(self.sns, other.sns))
|
||||||
|
|
||||||
|
|
||||||
|
# Script
|
||||||
|
def main():
|
||||||
|
|
||||||
|
try:
|
||||||
|
tt = Timetable(
|
||||||
|
np.load(Path(__file__).parents[2] / "data/cache.mca.npy"),
|
||||||
|
)
|
||||||
|
except FileNotFoundError:
|
||||||
|
tt = Timetable.from_zipfile()
|
||||||
|
np.save(Path(__file__).parents[2] / "data/cache.mca", tt.array)
|
||||||
|
|
||||||
|
result = Query(tt).on_date("260526").calls("CRMLNGT")
|
||||||
|
print(result.as_ux())
|
||||||
|
print(result.get_field(BX().retail_service_id))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user