Various minor updates, basic Schedule class. Added a SixData class to manage conversions of YYMMDD to/from more pythonic objects.

This commit is contained in:
2026-05-25 14:06:35 +01:00
parent c2633952d3
commit 36aa23f464
2 changed files with 107 additions and 18 deletions
+84 -5
View File
@@ -8,11 +8,13 @@ It seems unlikely that the CIF format will be modified any time soon.
""" """
# Imports # Imports
from dataclasses import dataclass, field
import sqlite3 import sqlite3
from datetime import datetime
from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
import numpy as np
import pandas as pd import pandas as pd
from sqlalchemy import ( from sqlalchemy import (
Column, Column,
@@ -24,11 +26,55 @@ from sqlalchemy import (
select, select,
) )
from national_rail_timetable.mca_record_types import (
BS,
BX,
CR,
LI,
LO,
LT,
BaseRecord,
)
from national_rail_timetable.parsing import validate_db_path from national_rail_timetable.parsing import validate_db_path
from national_rail_timetable.mca_record_types import LI
# Classes # Classes
class SixDate(str):
def __post_init__(self):
assert len(self) == 6
assert self.isnumeric()
def as_unix(self):
return f"20{self[:2]}-{self[2:4]}-{self[-2:]}"
def as_numpy(self):
return np.datetime64(self.as_unix())
def as_datetime(self):
return datetime.fromisoformat(self.as_unix())
@classmethod
def from_unix(cls, unix_str: str):
return cls(unix_str[2:].replace("-", "").replace("/", ""))
@classmethod
def from_numpy(cls, np_dt: np.datetime64):
return cls.from_unix(str(np_dt).split("T")[0])
@classmethod
def from_datetime(cls, dt: datetime):
return cls.from_unix(str(dt).split(" ")[0])
@dataclass
class Schedule:
sn: int
bs: pd.Series
bx: pd.Series
loit: pd.DataFrame
cr: pd.DataFrame
@dataclass @dataclass
class Timetable: class Timetable:
engine: Engine = field( engine: Engine = field(
@@ -76,7 +122,7 @@ class Timetable:
+ "# Result of mca_queries.py's Timetable._hardcode_table_dataclasses. \n" + "# Result of mca_queries.py's Timetable._hardcode_table_dataclasses. \n"
+ "\n" + "\n"
+ "# Imports \n" + "# Imports \n"
+ "from dataclasses import dataclass \n" + "from dataclasses import dataclass, field \n"
+ "from typing import Any \n" + "from typing import Any \n"
+ "from sqlalchemy import Column, MetaData, Table, String, Integer \n" + "from sqlalchemy import Column, MetaData, Table, String, Integer \n"
+ "\n" + "\n"
@@ -84,6 +130,13 @@ class Timetable:
+ "metadata = MetaData() \n" + "metadata = MetaData() \n"
+ "\n" + "\n"
+ "# Classes \n" + "# Classes \n"
+ "@dataclass \n"
+ "class BaseRecord: \n"
+ "\tall: Table \n\n"
+ "\t@property \n"
+ "\tdef schedule_number(self) -> Column[Integer]: ... \n\n"
+ "\t@property \n"
+ "\tdef line_number(self) -> Column[Integer]: ... \n\n"
) )
for name in self.tables: for name in self.tables:
columns = [column.name for column in self.tables[name].columns] columns = [column.name for column in self.tables[name].columns]
@@ -96,7 +149,7 @@ class Timetable:
) )
text += ( text += (
"@dataclass \n" "@dataclass \n"
+ f"class _{rr}_base: \n" + f"class _{rr}_base(BaseRecord): \n"
+ "\tall: Table \n\n" + "\tall: Table \n\n"
+ "".join( + "".join(
[ [
@@ -128,10 +181,36 @@ class Timetable:
with self.engine.connect() as connection: with self.engine.connect() as connection:
return pd.read_sql(query, connection) return pd.read_sql(query, connection)
def _fetch_record_of_schedule(
self,
schedule_number: int,
record_type: BaseRecord,
) -> pd.DataFrame:
return self.execute(
select(record_type.all).where(
record_type.schedule_number == schedule_number
)
)
def fetch_schedule(self, schedule_number: int) -> Schedule:
return Schedule(
sn=schedule_number,
bs=self._fetch_record_of_schedule(schedule_number, BS).iloc[0],
bx=self._fetch_record_of_schedule(schedule_number, BX).iloc[0],
loit=pd.concat(
[
self._fetch_record_of_schedule(schedule_number, LO),
self._fetch_record_of_schedule(schedule_number, LI),
self._fetch_record_of_schedule(schedule_number, LT),
]
).reset_index(drop=True),
cr=self._fetch_record_of_schedule(schedule_number, CR),
)
# Script # Script
def main(): def main():
print(Timetable().execute(select(LI.all).where(LI.schedule_number == 10))) print(Timetable().fetch_schedule(30))
return None return None
+23 -13
View File
@@ -5,7 +5,7 @@
# Result of mca_queries.py's Timetable._hardcode_table_dataclasses. # Result of mca_queries.py's Timetable._hardcode_table_dataclasses.
# Imports # Imports
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Any from typing import Any
from sqlalchemy import Column, MetaData, Table, String, Integer from sqlalchemy import Column, MetaData, Table, String, Integer
@@ -13,6 +13,16 @@ from sqlalchemy import Column, MetaData, Table, String, Integer
metadata = MetaData() metadata = MetaData()
# Classes # Classes
@dataclass
class BaseRecord:
all: Table
@property
def schedule_number(self) -> Column[Integer]: ...
@property
def line_number(self) -> Column[Integer]: ...
_BS_columns = Table( _BS_columns = Table(
'raw_mca_bs', 'raw_mca_bs',
metadata, metadata,
@@ -47,7 +57,7 @@ _BS_columns = Table(
) )
@dataclass @dataclass
class _BS_base: class _BS_base(BaseRecord):
all: Table all: Table
@property @property
@@ -183,7 +193,7 @@ _HD_columns = Table(
) )
@dataclass @dataclass
class _HD_base: class _HD_base(BaseRecord):
all: Table all: Table
@property @property
@@ -250,7 +260,7 @@ _ZZ_columns = Table(
) )
@dataclass @dataclass
class _ZZ_base: class _ZZ_base(BaseRecord):
all: Table all: Table
@property @property
@@ -291,7 +301,7 @@ _TA_columns = Table(
) )
@dataclass @dataclass
class _TA_base: class _TA_base(BaseRecord):
all: Table all: Table
@property @property
@@ -382,7 +392,7 @@ _CR_columns = Table(
) )
@dataclass @dataclass
class _CR_base: class _CR_base(BaseRecord):
all: Table all: Table
@property @property
@@ -499,7 +509,7 @@ _LT_columns = Table(
) )
@dataclass @dataclass
class _LT_base: class _LT_base(BaseRecord):
all: Table all: Table
@property @property
@@ -567,7 +577,7 @@ _LI_columns = Table(
) )
@dataclass @dataclass
class _LI_base: class _LI_base(BaseRecord):
all: Table all: Table
@property @property
@@ -651,7 +661,7 @@ _TD_columns = Table(
) )
@dataclass @dataclass
class _TD_base: class _TD_base(BaseRecord):
all: Table all: Table
@property @property
@@ -700,7 +710,7 @@ _AA_columns = Table(
) )
@dataclass @dataclass
class _AA_base: class _AA_base(BaseRecord):
all: Table all: Table
@property @property
@@ -796,7 +806,7 @@ _LO_columns = Table(
) )
@dataclass @dataclass
class _LO_base: class _LO_base(BaseRecord):
all: Table all: Table
@property @property
@@ -869,7 +879,7 @@ _BX_columns = Table(
) )
@dataclass @dataclass
class _BX_base: class _BX_base(BaseRecord):
all: Table all: Table
@property @property
@@ -933,7 +943,7 @@ _TI_columns = Table(
) )
@dataclass @dataclass
class _TI_base: class _TI_base(BaseRecord):
all: Table all: Table
@property @property