Tracked mca_querying.py.
This commit is contained in:
@@ -0,0 +1,191 @@
|
||||
"""
|
||||
MCA file querying to retrieve schedule(d service)s.
|
||||
"""
|
||||
# pyright: reportAny=false
|
||||
# pyright: reportUnknownVariableType=false
|
||||
# pyright: reportUnknownArgumentType=false
|
||||
# pyright: reportAttributeAccessIssue=false
|
||||
# pyright: reportOperatorIssue=false
|
||||
|
||||
# Imports
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Callable, Self
|
||||
from zipfile import ZipFile
|
||||
import numpy as np
|
||||
from numpy.typing import NDArray
|
||||
from national_rail_timetable.nr_requests import fetch_nr_timetable_files
|
||||
from national_rail_timetable.mca_stubs import BS, BX, LI, LO, LT, CR
|
||||
|
||||
|
||||
# Functions
|
||||
def ux(s1: NDArray[np.byte]) -> NDArray[np.str_]:
|
||||
x = s1.shape[1]
|
||||
s1.dtype = f"S{x}"
|
||||
return s1.astype(f"U{x}")[:, 0]
|
||||
|
||||
|
||||
def sx(s1: NDArray[np.byte]) -> NDArray[np.byte]:
|
||||
x = s1.shape[1]
|
||||
s1.dtype = f"S{x}"
|
||||
return s1[:, 0]
|
||||
|
||||
|
||||
# Classes
|
||||
@dataclass
|
||||
class Timetable:
|
||||
array: NDArray[np.byte]
|
||||
|
||||
@classmethod
|
||||
def from_zipfile(
|
||||
cls,
|
||||
zipfile: ZipFile | None = None,
|
||||
) -> Self:
|
||||
zipfile = zipfile if zipfile is not None else fetch_nr_timetable_files()
|
||||
with TemporaryDirectory() as tempdir:
|
||||
name = zipfile.extract(
|
||||
[
|
||||
zipinfo
|
||||
for name, zipinfo in zipfile.NameToInfo.items()
|
||||
if name.split(".")[-1] == "MCA"
|
||||
][0],
|
||||
path=tempdir,
|
||||
)
|
||||
array: NDArray[np.byte] = np.fromfile(
|
||||
Path(tempdir) / name, dtype="S1"
|
||||
).reshape((-1, 82))[:, :-2]
|
||||
return cls(array=array)
|
||||
|
||||
def bs_mask(self) -> NDArray[np.bool]:
|
||||
return (self.array[:, 0] == b"B") & (self.array[:, 1] == b"S")
|
||||
|
||||
def lo_mask(self) -> NDArray[np.bool]:
|
||||
return (self.array[:, 0] == b"L") & (self.array[:, 1] == b"O")
|
||||
|
||||
@property
|
||||
def sns(self) -> NDArray[np.integer]:
|
||||
return np.repeat(
|
||||
(
|
||||
sns := [
|
||||
0,
|
||||
*np.arange(self.array.shape[0])[self.bs_mask()],
|
||||
self.array.shape[0],
|
||||
]
|
||||
)[:-1],
|
||||
np.diff(sns),
|
||||
)
|
||||
|
||||
def fetch_schedules(self, *sns: int) -> NDArray[np.byte]:
|
||||
return self.array[np.isin(self.sns, sns)]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Query:
|
||||
tt: Timetable
|
||||
sns: NDArray[np.integer] | None = None
|
||||
|
||||
def _query_from_mask(self, mask: NDArray[np.bool]) -> Query:
|
||||
sns: NDArray[np.integer] = self.tt.sns[mask]
|
||||
if self.sns is not None:
|
||||
sns = np.intersect1d(sns, self.sns)
|
||||
return Query(self.tt, sns)
|
||||
|
||||
def on_date(self, date: str) -> Query:
|
||||
mask: NDArray[np.bool] = (
|
||||
(sx(self.tt.array[:, LO().record_identity()]) == b"BS")
|
||||
& (sx(self.tt.array[:, BS().date_runs_from()]) <= date.encode())
|
||||
& (sx(self.tt.array[:, BS().date_runs_to()]) >= date.encode())
|
||||
)
|
||||
return self._query_from_mask(mask)
|
||||
|
||||
def origin(self, tiploc: str) -> Query:
|
||||
mask: NDArray[np.bool] = (
|
||||
(sx(self.tt.array[:, LO().record_identity()]) == b"LO")
|
||||
& (
|
||||
sx(self.tt.array[:, LO().location().start : LO().location().stop - 1])
|
||||
== f"{tiploc:<7}".encode()
|
||||
),
|
||||
)[0]
|
||||
return self._query_from_mask(mask)
|
||||
|
||||
def dest(self, tiploc: str, call_number: str | int | None = None) -> Query:
|
||||
mask: NDArray[np.bool] = (
|
||||
(sx(self.tt.array[:, LO().record_identity()]) == b"LT")
|
||||
& (
|
||||
sx(self.tt.array[:, LT().location().start : LT().location().stop - 1])
|
||||
== f"{tiploc:<7}".encode()
|
||||
),
|
||||
)[0]
|
||||
if call_number is not None:
|
||||
mask &= (
|
||||
sx(self.tt.array[:, LT().location().stop]) == str(call_number).encode()
|
||||
)
|
||||
return self._query_from_mask(mask)
|
||||
|
||||
def calls(self, tiploc: str, call_number: str | int | None = None) -> Query:
|
||||
mask: NDArray[np.bool] = (
|
||||
(sx(self.tt.array[:, LI().record_identity()]) == b"LI")
|
||||
& (
|
||||
sx(self.tt.array[:, LI().location().start : LI().location().stop - 1])
|
||||
== f"{tiploc:<7}".encode()
|
||||
),
|
||||
)[0]
|
||||
if call_number is not None:
|
||||
mask &= (
|
||||
sx(self.tt.array[:, LI().location().stop]) == str(call_number).encode()
|
||||
)
|
||||
return self._query_from_mask(mask) | self.origin(tiploc) | self.dest(tiploc)
|
||||
|
||||
def get_field(self, mca_field: Callable[..., slice]):
|
||||
record_type = str(mca_field).split("method ")[1].split(".")[0]
|
||||
return sx(
|
||||
self.result[sx(self.result[:, :2]) == record_type.encode(), mca_field()]
|
||||
)
|
||||
|
||||
@property
|
||||
def _a(self):
|
||||
return self.tt.array
|
||||
|
||||
@property
|
||||
def result(self) -> NDArray[np.byte]:
|
||||
assert self.sns is not None
|
||||
return self.tt.fetch_schedules(*self.sns)
|
||||
|
||||
def as_sx(self) -> NDArray[np.byte]:
|
||||
return sx(self.result)
|
||||
|
||||
def as_ux(self) -> NDArray[np.str_]:
|
||||
return ux(self.result)
|
||||
|
||||
def __and__(self, other: Query) -> Query:
|
||||
assert self.tt is other.tt
|
||||
assert self.sns is not None
|
||||
assert other.sns is not None
|
||||
return Query(self.tt, np.intersect1d(self.sns, other.sns))
|
||||
|
||||
def __or__(self, other: Query) -> Query:
|
||||
assert self.tt is other.tt
|
||||
assert self.sns is not None
|
||||
assert other.sns is not None
|
||||
return Query(self.tt, np.union1d(self.sns, other.sns))
|
||||
|
||||
|
||||
# Script
|
||||
def main():
|
||||
|
||||
try:
|
||||
tt = Timetable(
|
||||
np.load(Path(__file__).parents[2] / "data/cache.mca.npy"),
|
||||
)
|
||||
except FileNotFoundError:
|
||||
tt = Timetable.from_zipfile()
|
||||
np.save(Path(__file__).parents[2] / "data/cache.mca", tt.array)
|
||||
|
||||
result = Query(tt).on_date("260526").calls("CRMLNGT")
|
||||
print(result.as_ux())
|
||||
print(result.get_field(BX().retail_service_id))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user