diff --git a/pyproject.toml b/pyproject.toml index 8757633..493c0f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,8 @@ dependencies = [ "pypdf (>=6.12.0,<7.0.0)", "pandas (>=3.0.3,<4.0.0)", "pandas-stubs (>=3.0.0.260204,<4.0.0.0)", - "sqlalchemy (>=2.0.49,<3.0.0)" + "sqlalchemy (>=2.0.49,<3.0.0)", + "py-spy (>=0.4.2,<0.5.0)" ] diff --git a/src/national_rail_timetable/mca_queries.py b/src/national_rail_timetable/mca_queries.py deleted file mode 100644 index 638ccaa..0000000 --- a/src/national_rail_timetable/mca_queries.py +++ /dev/null @@ -1,279 +0,0 @@ -""" -Queries for the 'raw_mca_...' tables generated in parsing.py. -Thus far, attempts at few assumptions for record types have been made. -These queries will outright expect certain properties of the databases generated. -If they suddenly stop working - it should be due to an RSP specification change. -Therefore, the error handling will be less graceful as this is not predictable. -It seems unlikely that the CIF format will be modified any time soon. -""" - -# Imports -from collections.abc import Iterable -import sqlite3 -from datetime import datetime -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any - -import numpy as np -import pandas as pd -from sqlalchemy import ( - Column, - Engine, - MetaData, - Select, - Table, - create_engine, - select, -) - -from national_rail_timetable.mca_record_types import ( - BS, - BX, - CR, - LI, - LO, - LT, - BaseRecord, -) -from national_rail_timetable.parsing import validate_db_path - - -# Classes -class SixDate(str): - def __post_init__(self): - assert len(self) == 6 - assert self.isnumeric() - - def as_unix(self): - return f"20{self[:2]}-{self[2:4]}-{self[-2:]}" - - def as_numpy(self): - return np.datetime64(self.as_unix()) - - def as_datetime(self): - return datetime.fromisoformat(self.as_unix()) - - @classmethod - def from_unix(cls, unix_str: str): - return cls(unix_str[2:].replace("-", "").replace("/", "")) - - @classmethod - def from_numpy(cls, np_dt: np.datetime64): - return cls.from_unix(str(np_dt).split("T")[0]) - - @classmethod - def from_datetime(cls, dt: datetime): - return cls.from_unix(str(dt).split(" ")[0]) - - @property - def weekday(self) -> int: - return self.as_datetime().weekday() - - @property - def weekday_like(self) -> str: - return "_" * self.weekday + "1%" - - -@dataclass -class Schedule: - sn: int - bs: pd.Series - bx: pd.Series - loit: pd.DataFrame - cr: pd.DataFrame - - -@dataclass -class Service(Schedule): - date: SixDate - - -@dataclass -class Timetable: - engine: Engine = field( - default_factory=lambda: create_engine( - url=f"sqlite:///{validate_db_path().as_posix()}" - ) - ) - metadata: MetaData = field(default_factory=MetaData) - tables: dict[str, Table] = field(default_factory=dict) - - def __post_init__(self): - self._populate_self_tables() - - def _populate_self_tables(self): - cursor = (connection := self._generate_sqlite3_connection()).cursor() - self.tables |= { - name: Table( - name, # pyright: ignore[reportAny] - self.metadata, - *[ - Column(d[0]) - for d in cursor.execute(f"SELECT * FROM {name} LIMIT 1").description - ], - ) - for name in [ # pyright: ignore[reportAny] - row[0] - for row in cursor.execute( # pyright: ignore[reportAny] - "SELECT name FROM sqlite_master WHERE type = 'table'" - ).fetchall() - if row[0][:8] == "raw_mca_" - ] - } - connection.close() - - def _generate_sqlite3_connection(self) -> sqlite3.Connection: - return sqlite3.connect(self.engine.url.__to_string__().split("///")[1]) - - # TODO: Implement docstrings from 'spec_mca_...' tables - def _hardcode_table_dataclasses(self): - text: str = ( - "# This file is pre-generated for type-hinting while writing MCA file queries. \n" - + "# Any changes made manually will likely be overwritten. \n" - + "# It should not need to be generated more than once. \n" - + "# If the RSP's timetable specification changes, then this will need to be updated. \n" - + "# Result of mca_queries.py's Timetable._hardcode_table_dataclasses. \n" - + "\n" - + "# Imports \n" - + "from dataclasses import dataclass, field \n" - + "from typing import Any \n" - + "from sqlalchemy import Column, MetaData, Table, String, Integer \n" - + "\n" - + "# Init. \n" - + "metadata = MetaData() \n" - + "\n" - + "# Classes \n" - + "@dataclass \n" - + "class BaseRecord: \n" - + "\tall: Table \n\n" - + "\t@property \n" - + "\tdef schedule_number(self) -> Column[Integer]: ... \n\n" - + "\t@property \n" - + "\tdef line_number(self) -> Column[Integer]: ... \n\n" - ) - for name in self.tables: - columns = [column.name for column in self.tables[name].columns] - text += ( - f"_{(rr := name.split('_')[-1].upper())}_columns = Table( \n" - + f"\t'{name}', \n" - + "\tmetadata, \n" - + "".join([f"\tColumn('{column}'), \n" for column in columns]) - + ") \n\n" - ) - text += ( - "@dataclass \n" - + f"class _{rr}_base(BaseRecord): \n" - + "\tall: Table \n\n" - + "".join( - [ - ( - "\t@property \n" - + f"\tdef {column}(self) -> " - + f"Column[{'String' if column not in ['line_number', 'schedule_number'] else 'Integer'}]: \n" - + f"\t\treturn self.all.c.{column} \n\n" - ) - for column in columns - ] - ) - ) - text += f"{rr} = _{rr}_base(_{rr}_columns) \n\n" - path = Path(__file__).parent / "mca_record_types.py" - with open(path, "w") as wf: - _ = wf.write(text.replace("\t", " ")) - - @classmethod - def from_db_path(cls, db_path: Path | None = None): - db_path = validate_db_path(db_path) - return cls(engine=create_engine(url=f"sqlite:///{db_path.as_posix()}")) - - @classmethod - def default(cls): - return cls.from_db_path() - - def execute(self, query: Select[Any]) -> pd.DataFrame: # pyright: ignore[reportExplicitAny] - with self.engine.connect() as connection: - return pd.read_sql(query, connection) - - def _fetch_records_of_schedules( - self, - record_type: BaseRecord, - *schedule_numbers: int, - ) -> pd.DataFrame: - return self.execute( - select(record_type.all).where( - record_type.schedule_number.in_(schedule_numbers) - ) - ) - - def fetch_schedules(self, *schedule_numbers: int) -> dict[int, Schedule]: - bs = self._fetch_records_of_schedules(BS, *schedule_numbers) - bx = self._fetch_records_of_schedules(BX, *schedule_numbers) - lo = self._fetch_records_of_schedules(LO, *schedule_numbers) - li = self._fetch_records_of_schedules(LI, *schedule_numbers) - lt = self._fetch_records_of_schedules(LT, *schedule_numbers) - cr = self._fetch_records_of_schedules(CR, *schedule_numbers) - return { - sn: Schedule( - sn=sn, - bs=bs[bs.schedule_number == sn].iloc[0], - bx=bx[bx.schedule_number == sn].iloc[0], - loit=pd.concat( - [ - lo[lo.schedule_number == sn], - li[li.schedule_number == sn], - lt[lt.schedule_number == sn], - ] - ).reset_index(drop=True), - cr=cr[cr.schedule_number == sn], - ) - for sn in schedule_numbers - } - - -# Functions -def services_date_and_tiploc( - date: SixDate, - tiploc: str, - tt: Timetable | None = None, -): - tt = tt if tt is not None else Timetable() - on_date: Iterable[int] = tt.execute( - select(BS.schedule_number).where( - (BS.date_runs_from <= date) - & (BS.date_runs_to >= date) - & (BS.days_run.like(date.weekday_like)) - ) - ).schedule_number.values - origin: Iterable[int] = tt.execute( - select(LO.schedule_number).where(LO.location == f"{tiploc:<8}") - ).schedule_number.values - en_route: Iterable[int] = tt.execute( - select(LI.schedule_number).where( - (LI.location == f"{tiploc:<8}") - & (LI.scheduled_departure_time.not_like(" %")) - ) - ).schedule_number.values - destination: Iterable[int] = tt.execute( - select(LT.schedule_number).where(LT.location == f"{tiploc:<8}") - ).schedule_number.values - sns = np.unique([*origin, *en_route, *destination]) - sns = np.intersect1d(sns, on_date) # pyright: ignore[reportCallIssue, reportUnknownVariableType, reportArgumentType] - sns = [int(sn) for sn in sns] # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType] - return [ - Service( - date=date, - **schedule.__dict__, # pyright: ignore[reportAny] - ) - for _, schedule in tt.fetch_schedules(*sns).items() - ] - - -# Script -def main(): - print(s := services_date_and_tiploc(SixDate("260525"), "YORK")) - return len(s) - - -if __name__ == "__main__": - print(main()) diff --git a/src/national_rail_timetable/mca_record_types.py b/src/national_rail_timetable/mca_record_types.py deleted file mode 100644 index 3a164bb..0000000 --- a/src/national_rail_timetable/mca_record_types.py +++ /dev/null @@ -1,1002 +0,0 @@ -# This file is pre-generated for type-hinting while writing MCA file queries. -# Any changes made manually will likely be overwritten. -# It should not need to be generated more than once. -# If the RSP's timetable specification changes, then this will need to be updated. -# Result of mca_queries.py's Timetable._hardcode_table_dataclasses. - -# Imports -from dataclasses import dataclass, field -from typing import Any -from sqlalchemy import Column, MetaData, Table, String, Integer - -# Init. -metadata = MetaData() - -# Classes -@dataclass -class BaseRecord: - all: Table - - @property - def schedule_number(self) -> Column[Integer]: ... - - @property - def line_number(self) -> Column[Integer]: ... - -_BS_columns = Table( - 'raw_mca_bs', - metadata, - Column('record_identity'), - Column('transaction_type'), - Column('train_uid'), - Column('date_runs_from'), - Column('date_runs_to'), - Column('days_run'), - Column('bank_holiday_running'), - Column('train_status'), - Column('train_category'), - Column('train_identity'), - Column('headcode'), - Column('course_indicator'), - Column('profit_centre_code'), - Column('business_sector'), - Column('power_type'), - Column('timing_load'), - Column('speed'), - Column('operating_chars'), - Column('train_class'), - Column('sleepers'), - Column('reservations'), - Column('connect_indicator'), - Column('catering_code'), - Column('service_branding'), - Column('spare'), - Column('stp_indicator'), - Column('line_number'), - Column('schedule_number'), -) - -@dataclass -class _BS_base(BaseRecord): - all: Table - - @property - def record_identity(self) -> Column[String]: - return self.all.c.record_identity - - @property - def transaction_type(self) -> Column[String]: - return self.all.c.transaction_type - - @property - def train_uid(self) -> Column[String]: - return self.all.c.train_uid - - @property - def date_runs_from(self) -> Column[String]: - return self.all.c.date_runs_from - - @property - def date_runs_to(self) -> Column[String]: - return self.all.c.date_runs_to - - @property - def days_run(self) -> Column[String]: - return self.all.c.days_run - - @property - def bank_holiday_running(self) -> Column[String]: - return self.all.c.bank_holiday_running - - @property - def train_status(self) -> Column[String]: - return self.all.c.train_status - - @property - def train_category(self) -> Column[String]: - return self.all.c.train_category - - @property - def train_identity(self) -> Column[String]: - return self.all.c.train_identity - - @property - def headcode(self) -> Column[String]: - return self.all.c.headcode - - @property - def course_indicator(self) -> Column[String]: - return self.all.c.course_indicator - - @property - def profit_centre_code(self) -> Column[String]: - return self.all.c.profit_centre_code - - @property - def business_sector(self) -> Column[String]: - return self.all.c.business_sector - - @property - def power_type(self) -> Column[String]: - return self.all.c.power_type - - @property - def timing_load(self) -> Column[String]: - return self.all.c.timing_load - - @property - def speed(self) -> Column[String]: - return self.all.c.speed - - @property - def operating_chars(self) -> Column[String]: - return self.all.c.operating_chars - - @property - def train_class(self) -> Column[String]: - return self.all.c.train_class - - @property - def sleepers(self) -> Column[String]: - return self.all.c.sleepers - - @property - def reservations(self) -> Column[String]: - return self.all.c.reservations - - @property - def connect_indicator(self) -> Column[String]: - return self.all.c.connect_indicator - - @property - def catering_code(self) -> Column[String]: - return self.all.c.catering_code - - @property - def service_branding(self) -> Column[String]: - return self.all.c.service_branding - - @property - def spare(self) -> Column[String]: - return self.all.c.spare - - @property - def stp_indicator(self) -> Column[String]: - return self.all.c.stp_indicator - - @property - def line_number(self) -> Column[Integer]: - return self.all.c.line_number - - @property - def schedule_number(self) -> Column[Integer]: - return self.all.c.schedule_number - -BS = _BS_base(_BS_columns) - -_HD_columns = Table( - 'raw_mca_hd', - metadata, - Column('record_identity'), - Column('file_identity'), - Column('date_of_extract'), - Column('time_of_extract'), - Column('current_file_reference'), - Column('last_file_reference'), - Column('update_indicator'), - Column('version'), - Column('extract_start_date'), - Column('extract_end_date'), - Column('spare'), - Column('line_number'), - Column('schedule_number'), -) - -@dataclass -class _HD_base(BaseRecord): - all: Table - - @property - def record_identity(self) -> Column[String]: - return self.all.c.record_identity - - @property - def file_identity(self) -> Column[String]: - return self.all.c.file_identity - - @property - def date_of_extract(self) -> Column[String]: - return self.all.c.date_of_extract - - @property - def time_of_extract(self) -> Column[String]: - return self.all.c.time_of_extract - - @property - def current_file_reference(self) -> Column[String]: - return self.all.c.current_file_reference - - @property - def last_file_reference(self) -> Column[String]: - return self.all.c.last_file_reference - - @property - def update_indicator(self) -> Column[String]: - return self.all.c.update_indicator - - @property - def version(self) -> Column[String]: - return self.all.c.version - - @property - def extract_start_date(self) -> Column[String]: - return self.all.c.extract_start_date - - @property - def extract_end_date(self) -> Column[String]: - return self.all.c.extract_end_date - - @property - def spare(self) -> Column[String]: - return self.all.c.spare - - @property - def line_number(self) -> Column[Integer]: - return self.all.c.line_number - - @property - def schedule_number(self) -> Column[Integer]: - return self.all.c.schedule_number - -HD = _HD_base(_HD_columns) - -_ZZ_columns = Table( - 'raw_mca_zz', - metadata, - Column('record_identity'), - Column('spare'), - Column('line_number'), - Column('schedule_number'), -) - -@dataclass -class _ZZ_base(BaseRecord): - all: Table - - @property - def record_identity(self) -> Column[String]: - return self.all.c.record_identity - - @property - def spare(self) -> Column[String]: - return self.all.c.spare - - @property - def line_number(self) -> Column[Integer]: - return self.all.c.line_number - - @property - def schedule_number(self) -> Column[Integer]: - return self.all.c.schedule_number - -ZZ = _ZZ_base(_ZZ_columns) - -_TA_columns = Table( - 'raw_mca_ta', - metadata, - Column('record_identity'), - Column('tiploc_code'), - Column('capitals'), - Column('national_location_code'), - Column('nlc_check_character'), - Column('tps_description'), - Column('stanox'), - Column('po_mcp_code'), - Column('crs_code'), - Column('description'), - Column('new_tiploc'), - Column('spare'), - Column('line_number'), - Column('schedule_number'), -) - -@dataclass -class _TA_base(BaseRecord): - all: Table - - @property - def record_identity(self) -> Column[String]: - return self.all.c.record_identity - - @property - def tiploc_code(self) -> Column[String]: - return self.all.c.tiploc_code - - @property - def capitals(self) -> Column[String]: - return self.all.c.capitals - - @property - def national_location_code(self) -> Column[String]: - return self.all.c.national_location_code - - @property - def nlc_check_character(self) -> Column[String]: - return self.all.c.nlc_check_character - - @property - def tps_description(self) -> Column[String]: - return self.all.c.tps_description - - @property - def stanox(self) -> Column[String]: - return self.all.c.stanox - - @property - def po_mcp_code(self) -> Column[String]: - return self.all.c.po_mcp_code - - @property - def crs_code(self) -> Column[String]: - return self.all.c.crs_code - - @property - def description(self) -> Column[String]: - return self.all.c.description - - @property - def new_tiploc(self) -> Column[String]: - return self.all.c.new_tiploc - - @property - def spare(self) -> Column[String]: - return self.all.c.spare - - @property - def line_number(self) -> Column[Integer]: - return self.all.c.line_number - - @property - def schedule_number(self) -> Column[Integer]: - return self.all.c.schedule_number - -TA = _TA_base(_TA_columns) - -_CR_columns = Table( - 'raw_mca_cr', - metadata, - Column('record_identity'), - Column('location'), - Column('train_category'), - Column('train_identity'), - Column('headcode'), - Column('course_indicator'), - Column('profit_centre_code'), - Column('business_sector'), - Column('power_type'), - Column('timing_load'), - Column('speed'), - Column('operating_chars'), - Column('train_class'), - Column('sleepers'), - Column('reservations'), - Column('connect_indicator'), - Column('catering_code'), - Column('service_branding'), - Column('traction_class'), - Column('uic_code'), - Column('retail_service_id'), - Column('spare'), - Column('line_number'), - Column('schedule_number'), -) - -@dataclass -class _CR_base(BaseRecord): - all: Table - - @property - def record_identity(self) -> Column[String]: - return self.all.c.record_identity - - @property - def location(self) -> Column[String]: - return self.all.c.location - - @property - def train_category(self) -> Column[String]: - return self.all.c.train_category - - @property - def train_identity(self) -> Column[String]: - return self.all.c.train_identity - - @property - def headcode(self) -> Column[String]: - return self.all.c.headcode - - @property - def course_indicator(self) -> Column[String]: - return self.all.c.course_indicator - - @property - def profit_centre_code(self) -> Column[String]: - return self.all.c.profit_centre_code - - @property - def business_sector(self) -> Column[String]: - return self.all.c.business_sector - - @property - def power_type(self) -> Column[String]: - return self.all.c.power_type - - @property - def timing_load(self) -> Column[String]: - return self.all.c.timing_load - - @property - def speed(self) -> Column[String]: - return self.all.c.speed - - @property - def operating_chars(self) -> Column[String]: - return self.all.c.operating_chars - - @property - def train_class(self) -> Column[String]: - return self.all.c.train_class - - @property - def sleepers(self) -> Column[String]: - return self.all.c.sleepers - - @property - def reservations(self) -> Column[String]: - return self.all.c.reservations - - @property - def connect_indicator(self) -> Column[String]: - return self.all.c.connect_indicator - - @property - def catering_code(self) -> Column[String]: - return self.all.c.catering_code - - @property - def service_branding(self) -> Column[String]: - return self.all.c.service_branding - - @property - def traction_class(self) -> Column[String]: - return self.all.c.traction_class - - @property - def uic_code(self) -> Column[String]: - return self.all.c.uic_code - - @property - def retail_service_id(self) -> Column[String]: - return self.all.c.retail_service_id - - @property - def spare(self) -> Column[String]: - return self.all.c.spare - - @property - def line_number(self) -> Column[Integer]: - return self.all.c.line_number - - @property - def schedule_number(self) -> Column[Integer]: - return self.all.c.schedule_number - -CR = _CR_base(_CR_columns) - -_LT_columns = Table( - 'raw_mca_lt', - metadata, - Column('record_identity'), - Column('location'), - Column('scheduled_arrival_time'), - Column('public_arrival_time'), - Column('platform'), - Column('path'), - Column('activity'), - Column('spare'), - Column('line_number'), - Column('schedule_number'), -) - -@dataclass -class _LT_base(BaseRecord): - all: Table - - @property - def record_identity(self) -> Column[String]: - return self.all.c.record_identity - - @property - def location(self) -> Column[String]: - return self.all.c.location - - @property - def scheduled_arrival_time(self) -> Column[String]: - return self.all.c.scheduled_arrival_time - - @property - def public_arrival_time(self) -> Column[String]: - return self.all.c.public_arrival_time - - @property - def platform(self) -> Column[String]: - return self.all.c.platform - - @property - def path(self) -> Column[String]: - return self.all.c.path - - @property - def activity(self) -> Column[String]: - return self.all.c.activity - - @property - def spare(self) -> Column[String]: - return self.all.c.spare - - @property - def line_number(self) -> Column[Integer]: - return self.all.c.line_number - - @property - def schedule_number(self) -> Column[Integer]: - return self.all.c.schedule_number - -LT = _LT_base(_LT_columns) - -_LI_columns = Table( - 'raw_mca_li', - metadata, - Column('record_identity'), - Column('location'), - Column('scheduled_arrival_time'), - Column('scheduled_departure_time'), - Column('scheduled_pass'), - Column('public_arrival'), - Column('public_departure'), - Column('platform'), - Column('line'), - Column('path'), - Column('activity'), - Column('engineering_allowance'), - Column('pathing_allowance'), - Column('performance_allowance'), - Column('spare'), - Column('line_number'), - Column('schedule_number'), -) - -@dataclass -class _LI_base(BaseRecord): - all: Table - - @property - def record_identity(self) -> Column[String]: - return self.all.c.record_identity - - @property - def location(self) -> Column[String]: - return self.all.c.location - - @property - def scheduled_arrival_time(self) -> Column[String]: - return self.all.c.scheduled_arrival_time - - @property - def scheduled_departure_time(self) -> Column[String]: - return self.all.c.scheduled_departure_time - - @property - def scheduled_pass(self) -> Column[String]: - return self.all.c.scheduled_pass - - @property - def public_arrival(self) -> Column[String]: - return self.all.c.public_arrival - - @property - def public_departure(self) -> Column[String]: - return self.all.c.public_departure - - @property - def platform(self) -> Column[String]: - return self.all.c.platform - - @property - def line(self) -> Column[String]: - return self.all.c.line - - @property - def path(self) -> Column[String]: - return self.all.c.path - - @property - def activity(self) -> Column[String]: - return self.all.c.activity - - @property - def engineering_allowance(self) -> Column[String]: - return self.all.c.engineering_allowance - - @property - def pathing_allowance(self) -> Column[String]: - return self.all.c.pathing_allowance - - @property - def performance_allowance(self) -> Column[String]: - return self.all.c.performance_allowance - - @property - def spare(self) -> Column[String]: - return self.all.c.spare - - @property - def line_number(self) -> Column[Integer]: - return self.all.c.line_number - - @property - def schedule_number(self) -> Column[Integer]: - return self.all.c.schedule_number - -LI = _LI_base(_LI_columns) - -_TD_columns = Table( - 'raw_mca_td', - metadata, - Column('record_identity'), - Column('tiploc_code'), - Column('spare'), - Column('line_number'), - Column('schedule_number'), -) - -@dataclass -class _TD_base(BaseRecord): - all: Table - - @property - def record_identity(self) -> Column[String]: - return self.all.c.record_identity - - @property - def tiploc_code(self) -> Column[String]: - return self.all.c.tiploc_code - - @property - def spare(self) -> Column[String]: - return self.all.c.spare - - @property - def line_number(self) -> Column[Integer]: - return self.all.c.line_number - - @property - def schedule_number(self) -> Column[Integer]: - return self.all.c.schedule_number - -TD = _TD_base(_TD_columns) - -_AA_columns = Table( - 'raw_mca_aa', - metadata, - Column('record_identity'), - Column('transaction_type'), - Column('base_uid'), - Column('assoc_uid'), - Column('assoc_start_date'), - Column('assoc_end_date'), - Column('assoc_days'), - Column('assoc_cat'), - Column('assoc_date_ind'), - Column('assoc_location'), - Column('base_location_suffix'), - Column('assoc_location_suffix'), - Column('diagram_type'), - Column('association_type'), - Column('filler'), - Column('stp_indicator'), - Column('line_number'), - Column('schedule_number'), -) - -@dataclass -class _AA_base(BaseRecord): - all: Table - - @property - def record_identity(self) -> Column[String]: - return self.all.c.record_identity - - @property - def transaction_type(self) -> Column[String]: - return self.all.c.transaction_type - - @property - def base_uid(self) -> Column[String]: - return self.all.c.base_uid - - @property - def assoc_uid(self) -> Column[String]: - return self.all.c.assoc_uid - - @property - def assoc_start_date(self) -> Column[String]: - return self.all.c.assoc_start_date - - @property - def assoc_end_date(self) -> Column[String]: - return self.all.c.assoc_end_date - - @property - def assoc_days(self) -> Column[String]: - return self.all.c.assoc_days - - @property - def assoc_cat(self) -> Column[String]: - return self.all.c.assoc_cat - - @property - def assoc_date_ind(self) -> Column[String]: - return self.all.c.assoc_date_ind - - @property - def assoc_location(self) -> Column[String]: - return self.all.c.assoc_location - - @property - def base_location_suffix(self) -> Column[String]: - return self.all.c.base_location_suffix - - @property - def assoc_location_suffix(self) -> Column[String]: - return self.all.c.assoc_location_suffix - - @property - def diagram_type(self) -> Column[String]: - return self.all.c.diagram_type - - @property - def association_type(self) -> Column[String]: - return self.all.c.association_type - - @property - def filler(self) -> Column[String]: - return self.all.c.filler - - @property - def stp_indicator(self) -> Column[String]: - return self.all.c.stp_indicator - - @property - def line_number(self) -> Column[Integer]: - return self.all.c.line_number - - @property - def schedule_number(self) -> Column[Integer]: - return self.all.c.schedule_number - -AA = _AA_base(_AA_columns) - -_LO_columns = Table( - 'raw_mca_lo', - metadata, - Column('record_identity'), - Column('location'), - Column('scheduled_departure_time'), - Column('public_departure_time'), - Column('platform'), - Column('line'), - Column('engineering_allowance'), - Column('pathing_allowance'), - Column('activity'), - Column('performance_allowance'), - Column('spare'), - Column('line_number'), - Column('schedule_number'), -) - -@dataclass -class _LO_base(BaseRecord): - all: Table - - @property - def record_identity(self) -> Column[String]: - return self.all.c.record_identity - - @property - def location(self) -> Column[String]: - return self.all.c.location - - @property - def scheduled_departure_time(self) -> Column[String]: - return self.all.c.scheduled_departure_time - - @property - def public_departure_time(self) -> Column[String]: - return self.all.c.public_departure_time - - @property - def platform(self) -> Column[String]: - return self.all.c.platform - - @property - def line(self) -> Column[String]: - return self.all.c.line - - @property - def engineering_allowance(self) -> Column[String]: - return self.all.c.engineering_allowance - - @property - def pathing_allowance(self) -> Column[String]: - return self.all.c.pathing_allowance - - @property - def activity(self) -> Column[String]: - return self.all.c.activity - - @property - def performance_allowance(self) -> Column[String]: - return self.all.c.performance_allowance - - @property - def spare(self) -> Column[String]: - return self.all.c.spare - - @property - def line_number(self) -> Column[Integer]: - return self.all.c.line_number - - @property - def schedule_number(self) -> Column[Integer]: - return self.all.c.schedule_number - -LO = _LO_base(_LO_columns) - -_BX_columns = Table( - 'raw_mca_bx', - metadata, - Column('record_identity'), - Column('traction_class'), - Column('uic_code'), - Column('atoc_code'), - Column('applicable_timetable_code'), - Column('retail_service_id'), - Column('source'), - Column('spare'), - Column('line_number'), - Column('schedule_number'), -) - -@dataclass -class _BX_base(BaseRecord): - all: Table - - @property - def record_identity(self) -> Column[String]: - return self.all.c.record_identity - - @property - def traction_class(self) -> Column[String]: - return self.all.c.traction_class - - @property - def uic_code(self) -> Column[String]: - return self.all.c.uic_code - - @property - def atoc_code(self) -> Column[String]: - return self.all.c.atoc_code - - @property - def applicable_timetable_code(self) -> Column[String]: - return self.all.c.applicable_timetable_code - - @property - def retail_service_id(self) -> Column[String]: - return self.all.c.retail_service_id - - @property - def source(self) -> Column[String]: - return self.all.c.source - - @property - def spare(self) -> Column[String]: - return self.all.c.spare - - @property - def line_number(self) -> Column[Integer]: - return self.all.c.line_number - - @property - def schedule_number(self) -> Column[Integer]: - return self.all.c.schedule_number - -BX = _BX_base(_BX_columns) - -_TI_columns = Table( - 'raw_mca_ti', - metadata, - Column('record_identity'), - Column('tiploc_code'), - Column('capitals'), - Column('national_location_code'), - Column('nlc_check_character'), - Column('tps_description'), - Column('stanox'), - Column('po_mcp_code'), - Column('crs_code'), - Column('description'), - Column('spare'), - Column('line_number'), - Column('schedule_number'), -) - -@dataclass -class _TI_base(BaseRecord): - all: Table - - @property - def record_identity(self) -> Column[String]: - return self.all.c.record_identity - - @property - def tiploc_code(self) -> Column[String]: - return self.all.c.tiploc_code - - @property - def capitals(self) -> Column[String]: - return self.all.c.capitals - - @property - def national_location_code(self) -> Column[String]: - return self.all.c.national_location_code - - @property - def nlc_check_character(self) -> Column[String]: - return self.all.c.nlc_check_character - - @property - def tps_description(self) -> Column[String]: - return self.all.c.tps_description - - @property - def stanox(self) -> Column[String]: - return self.all.c.stanox - - @property - def po_mcp_code(self) -> Column[String]: - return self.all.c.po_mcp_code - - @property - def crs_code(self) -> Column[String]: - return self.all.c.crs_code - - @property - def description(self) -> Column[String]: - return self.all.c.description - - @property - def spare(self) -> Column[String]: - return self.all.c.spare - - @property - def line_number(self) -> Column[Integer]: - return self.all.c.line_number - - @property - def schedule_number(self) -> Column[Integer]: - return self.all.c.schedule_number - -TI = _TI_base(_TI_columns) - diff --git a/src/national_rail_timetable/parsing.py b/src/national_rail_timetable/parsing.py index 8dbb5a3..9f20020 100644 --- a/src/national_rail_timetable/parsing.py +++ b/src/national_rail_timetable/parsing.py @@ -9,17 +9,13 @@ Aimed primarily towards producing a reduced sqlite database. # pyright: reportUnknownLambdaType=false # Imports -import os -import sqlite3 from itertools import pairwise from pathlib import Path -from zipfile import ZipFile import numpy as np import pandas as pd from pypdf import PageObject, PdfReader -from national_rail_timetable.nr_requests import fetch_nr_timetable_files # Init. SPECIFICATION_TABLE_LOCATIONS = { @@ -37,7 +33,6 @@ SPECIFICATION_TABLE_LOCATIONS = { "MCA_ZZ": (23, 1), } DEFAULT_RAW_SPEC_DATA_DIR = Path(__file__).parents[2] / "data/specification_tables" -DEFAULT_DB_PATH = Path.home() / ".cache/nr_data/timetable.db" # Functions @@ -151,165 +146,3 @@ def read_specification_table_raws( continue tables[path.name[:-4]] = pd.read_csv(path) return tables - - -def validate_db_path(db_path: Path | None = None): - db_path = ( - db_path - if db_path is not None - else Path(os.environ.get("NR_DATADIR", DEFAULT_DB_PATH)) - ) - db_path.parent.mkdir(exist_ok=True, parents=True) - return db_path - - -def create_mca_specification_dbtables( - tables: dict[str, pd.DataFrame], - db_path: Path | None = None, -): - db_path = validate_db_path(db_path) - connection = sqlite3.connect(db_path) - cursor = connection.cursor() - for name, df in tables.items(): - if (_n := name.split("_"))[0] != "MCA" or len(_n) != 2: - continue - df["Start Index"] = df["Position"].apply(lambda s: int(s.split("-")[0]) - 1) - df["End Index"] = df["Position"].apply(lambda s: int(s.split("-")[-1])) - - _ = cursor.execute(f"DROP TABLE IF EXISTS spec_{name.lower()}") - _ = cursor.execute( - f""" - CREATE TABLE spec_{name.lower()} - ({", ".join([col.lower().replace(" ", "_") for col in df.columns])}) - """, - ) - _ = cursor.executemany( - f""" - INSERT INTO spec_{name.lower()} - VALUES({", ".join(["?" for _ in df.columns])}) - """, - [tuple(row.values) for _, row in df.iterrows()], - ) - connection.commit() - connection.close() - return db_path - - -# TODO: There is no need for this to take minutes, investigate speed ups. -def create_mca_raw_dbtables( - zipfile: ZipFile | None = None, - db_path: Path | None = None, - allow_fetch: bool = True, - print_progress: bool = True, -) -> dict[str, str]: - db_path = validate_db_path(db_path) - if zipfile is None: - if allow_fetch: - zipfile = fetch_nr_timetable_files() - else: - raise RuntimeError( - "There was no zipfile provided and allow_fetch is set to False. " - + "Please either allow automatic fetching, or supply the zipfile argument. " - + "The package's fetching function is fetch_nr_timetable_files. " - ) - - connection = sqlite3.connect(db_path) - cursor = connection.cursor() - tables = [ - row[0] - for row in cursor.execute( - "SELECT name FROM sqlite_master WHERE type = 'table'" - ).fetchall() - if row[0][:9] == "spec_mca_" - ] - if len(tables) == 0: - raise FileNotFoundError( - "No spec_mca_... tables found in given database. " - + "Please ensure create_mca_specification_dbtables has been run successfully. " - ) - mappings = {} - all_start_indexes = {} - all_end_indexes = {} - for name in tables: - mappings[name] = "raw_" + name[5:] - spec = pd.DataFrame( - cursor.execute(f"SELECT * FROM {name}").fetchall(), - columns=[d[0] for d in cursor.description], - ) - all_start_indexes[mappings[name]] = spec.start_index.values - all_end_indexes[mappings[name]] = spec.end_index.values - new_columns = [ - col.split("/")[0].lower().replace(" ", "_").replace("-", "_") - for col in spec.field_description - ] + ["line_number", "schedule_number"] - _ = cursor.execute(f"DROP TABLE IF EXISTS {mappings[name].lower()}") - _ = cursor.execute( - f""" - CREATE TABLE {mappings[name]} - ({(", ".join(new_columns))}) - """, - ) - connection.commit() - - zipinfo = [ - zipinfo - for name, zipinfo in zipfile.NameToInfo.items() - if name.split(".")[-1] == "MCA" - ][0] - file = zipfile.open(zipinfo) - schedule_number, line_number = -1, -1 - while (line := file.readline().decode()) != "": - line_number += 1 - record_type = line[:2] - schedule_number += int(record_type == "BS") - target_table = f"raw_mca_{record_type.lower()}" - start_indexes = all_start_indexes.get(target_table) - end_indexes = all_end_indexes.get(target_table) - if start_indexes is None or end_indexes is None: - continue - values = ", ".join( - [ - "'" + line[lb:ub].replace("'", "") + "'" - for lb, ub in zip(start_indexes, end_indexes, strict=True) - ] - + [f"{line_number}", f"{schedule_number}"] - ) - _ = cursor.execute(f"INSERT INTO {target_table} VALUES({values})") - if line_number % 3737 == 0 and print_progress: - print(f" {line_number:9,} {line[:-1]}", end="\r") - if print_progress: - print() - connection.commit() - connection.close() - return mappings - - -# Script -def main( - skip_pdf: bool = False, - pdf_spec_path: Path | None = None, - raw_spec_dir: Path | None = None, -): - if not skip_pdf: - try: - tables = extract_specification_document_tables(pdf_spec_path) - _ = store_specification_table_raws(tables, raw_spec_dir) - except FileNotFoundError: - pass - - try: - tables = read_specification_table_raws(raw_spec_dir) - except FileNotFoundError: - raise FileNotFoundError( - "The tables generated from the RSP's specification were not found. " - + "This means neither the cached version nor the original .pdf is available. " - + "Try suppling either to their default locations, or supplying custom directories. " - + "Manual fix: extract_specification_document_tables then store_specification_table_raws. " - ) - - _ = create_mca_specification_dbtables(tables) - return create_mca_raw_dbtables() - - -if __name__ == "__main__": - _ = main()