diff --git a/src/national_rail_timetable/mca_queries.py b/src/national_rail_timetable/mca_queries.py index f7f473d..223b8ad 100644 --- a/src/national_rail_timetable/mca_queries.py +++ b/src/national_rail_timetable/mca_queries.py @@ -8,11 +8,13 @@ It seems unlikely that the CIF format will be modified any time soon. """ # Imports -from dataclasses import dataclass, field import sqlite3 +from datetime import datetime +from dataclasses import dataclass, field from pathlib import Path from typing import Any +import numpy as np import pandas as pd from sqlalchemy import ( Column, @@ -24,11 +26,55 @@ from sqlalchemy import ( select, ) +from national_rail_timetable.mca_record_types import ( + BS, + BX, + CR, + LI, + LO, + LT, + BaseRecord, +) from national_rail_timetable.parsing import validate_db_path -from national_rail_timetable.mca_record_types import LI # Classes +class SixDate(str): + def __post_init__(self): + assert len(self) == 6 + assert self.isnumeric() + + def as_unix(self): + return f"20{self[:2]}-{self[2:4]}-{self[-2:]}" + + def as_numpy(self): + return np.datetime64(self.as_unix()) + + def as_datetime(self): + return datetime.fromisoformat(self.as_unix()) + + @classmethod + def from_unix(cls, unix_str: str): + return cls(unix_str[2:].replace("-", "").replace("/", "")) + + @classmethod + def from_numpy(cls, np_dt: np.datetime64): + return cls.from_unix(str(np_dt).split("T")[0]) + + @classmethod + def from_datetime(cls, dt: datetime): + return cls.from_unix(str(dt).split(" ")[0]) + + +@dataclass +class Schedule: + sn: int + bs: pd.Series + bx: pd.Series + loit: pd.DataFrame + cr: pd.DataFrame + + @dataclass class Timetable: engine: Engine = field( @@ -76,7 +122,7 @@ class Timetable: + "# Result of mca_queries.py's Timetable._hardcode_table_dataclasses. \n" + "\n" + "# Imports \n" - + "from dataclasses import dataclass \n" + + "from dataclasses import dataclass, field \n" + "from typing import Any \n" + "from sqlalchemy import Column, MetaData, Table, String, Integer \n" + "\n" @@ -84,6 +130,13 @@ class Timetable: + "metadata = MetaData() \n" + "\n" + "# Classes \n" + + "@dataclass \n" + + "class BaseRecord: \n" + + "\tall: Table \n\n" + + "\t@property \n" + + "\tdef schedule_number(self) -> Column[Integer]: ... \n\n" + + "\t@property \n" + + "\tdef line_number(self) -> Column[Integer]: ... \n\n" ) for name in self.tables: columns = [column.name for column in self.tables[name].columns] @@ -96,7 +149,7 @@ class Timetable: ) text += ( "@dataclass \n" - + f"class _{rr}_base: \n" + + f"class _{rr}_base(BaseRecord): \n" + "\tall: Table \n\n" + "".join( [ @@ -128,10 +181,36 @@ class Timetable: with self.engine.connect() as connection: return pd.read_sql(query, connection) + def _fetch_record_of_schedule( + self, + schedule_number: int, + record_type: BaseRecord, + ) -> pd.DataFrame: + return self.execute( + select(record_type.all).where( + record_type.schedule_number == schedule_number + ) + ) + + def fetch_schedule(self, schedule_number: int) -> Schedule: + return Schedule( + sn=schedule_number, + bs=self._fetch_record_of_schedule(schedule_number, BS).iloc[0], + bx=self._fetch_record_of_schedule(schedule_number, BX).iloc[0], + loit=pd.concat( + [ + self._fetch_record_of_schedule(schedule_number, LO), + self._fetch_record_of_schedule(schedule_number, LI), + self._fetch_record_of_schedule(schedule_number, LT), + ] + ).reset_index(drop=True), + cr=self._fetch_record_of_schedule(schedule_number, CR), + ) + # Script def main(): - print(Timetable().execute(select(LI.all).where(LI.schedule_number == 10))) + print(Timetable().fetch_schedule(30)) return None diff --git a/src/national_rail_timetable/mca_record_types.py b/src/national_rail_timetable/mca_record_types.py index abd9447..3a164bb 100644 --- a/src/national_rail_timetable/mca_record_types.py +++ b/src/national_rail_timetable/mca_record_types.py @@ -5,7 +5,7 @@ # Result of mca_queries.py's Timetable._hardcode_table_dataclasses. # Imports -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any from sqlalchemy import Column, MetaData, Table, String, Integer @@ -13,6 +13,16 @@ from sqlalchemy import Column, MetaData, Table, String, Integer metadata = MetaData() # Classes +@dataclass +class BaseRecord: + all: Table + + @property + def schedule_number(self) -> Column[Integer]: ... + + @property + def line_number(self) -> Column[Integer]: ... + _BS_columns = Table( 'raw_mca_bs', metadata, @@ -47,7 +57,7 @@ _BS_columns = Table( ) @dataclass -class _BS_base: +class _BS_base(BaseRecord): all: Table @property @@ -183,7 +193,7 @@ _HD_columns = Table( ) @dataclass -class _HD_base: +class _HD_base(BaseRecord): all: Table @property @@ -250,7 +260,7 @@ _ZZ_columns = Table( ) @dataclass -class _ZZ_base: +class _ZZ_base(BaseRecord): all: Table @property @@ -291,7 +301,7 @@ _TA_columns = Table( ) @dataclass -class _TA_base: +class _TA_base(BaseRecord): all: Table @property @@ -382,7 +392,7 @@ _CR_columns = Table( ) @dataclass -class _CR_base: +class _CR_base(BaseRecord): all: Table @property @@ -499,7 +509,7 @@ _LT_columns = Table( ) @dataclass -class _LT_base: +class _LT_base(BaseRecord): all: Table @property @@ -567,7 +577,7 @@ _LI_columns = Table( ) @dataclass -class _LI_base: +class _LI_base(BaseRecord): all: Table @property @@ -651,7 +661,7 @@ _TD_columns = Table( ) @dataclass -class _TD_base: +class _TD_base(BaseRecord): all: Table @property @@ -700,7 +710,7 @@ _AA_columns = Table( ) @dataclass -class _AA_base: +class _AA_base(BaseRecord): all: Table @property @@ -796,7 +806,7 @@ _LO_columns = Table( ) @dataclass -class _LO_base: +class _LO_base(BaseRecord): all: Table @property @@ -869,7 +879,7 @@ _BX_columns = Table( ) @dataclass -class _BX_base: +class _BX_base(BaseRecord): all: Table @property @@ -933,7 +943,7 @@ _TI_columns = Table( ) @dataclass -class _TI_base: +class _TI_base(BaseRecord): all: Table @property