Removal of sqlite/sqlalchemy based approach - it is too slow to combine the results, even with in-memory database loading.

This commit is contained in:
2026-05-26 10:17:28 +01:00
parent 0479f1e4a8
commit e887cc791e
4 changed files with 2 additions and 1449 deletions
+2 -1
View File
@@ -11,7 +11,8 @@ dependencies = [
"pypdf (>=6.12.0,<7.0.0)",
"pandas (>=3.0.3,<4.0.0)",
"pandas-stubs (>=3.0.0.260204,<4.0.0.0)",
"sqlalchemy (>=2.0.49,<3.0.0)"
"sqlalchemy (>=2.0.49,<3.0.0)",
"py-spy (>=0.4.2,<0.5.0)"
]
-279
View File
@@ -1,279 +0,0 @@
"""
Queries for the 'raw_mca_...' tables generated in parsing.py.
Thus far, attempts at few assumptions for record types have been made.
These queries will outright expect certain properties of the databases generated.
If they suddenly stop working - it should be due to an RSP specification change.
Therefore, the error handling will be less graceful as this is not predictable.
It seems unlikely that the CIF format will be modified any time soon.
"""
# Imports
from collections.abc import Iterable
import sqlite3
from datetime import datetime
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import numpy as np
import pandas as pd
from sqlalchemy import (
Column,
Engine,
MetaData,
Select,
Table,
create_engine,
select,
)
from national_rail_timetable.mca_record_types import (
BS,
BX,
CR,
LI,
LO,
LT,
BaseRecord,
)
from national_rail_timetable.parsing import validate_db_path
# Classes
class SixDate(str):
def __post_init__(self):
assert len(self) == 6
assert self.isnumeric()
def as_unix(self):
return f"20{self[:2]}-{self[2:4]}-{self[-2:]}"
def as_numpy(self):
return np.datetime64(self.as_unix())
def as_datetime(self):
return datetime.fromisoformat(self.as_unix())
@classmethod
def from_unix(cls, unix_str: str):
return cls(unix_str[2:].replace("-", "").replace("/", ""))
@classmethod
def from_numpy(cls, np_dt: np.datetime64):
return cls.from_unix(str(np_dt).split("T")[0])
@classmethod
def from_datetime(cls, dt: datetime):
return cls.from_unix(str(dt).split(" ")[0])
@property
def weekday(self) -> int:
return self.as_datetime().weekday()
@property
def weekday_like(self) -> str:
return "_" * self.weekday + "1%"
@dataclass
class Schedule:
sn: int
bs: pd.Series
bx: pd.Series
loit: pd.DataFrame
cr: pd.DataFrame
@dataclass
class Service(Schedule):
date: SixDate
@dataclass
class Timetable:
engine: Engine = field(
default_factory=lambda: create_engine(
url=f"sqlite:///{validate_db_path().as_posix()}"
)
)
metadata: MetaData = field(default_factory=MetaData)
tables: dict[str, Table] = field(default_factory=dict)
def __post_init__(self):
self._populate_self_tables()
def _populate_self_tables(self):
cursor = (connection := self._generate_sqlite3_connection()).cursor()
self.tables |= {
name: Table(
name, # pyright: ignore[reportAny]
self.metadata,
*[
Column(d[0])
for d in cursor.execute(f"SELECT * FROM {name} LIMIT 1").description
],
)
for name in [ # pyright: ignore[reportAny]
row[0]
for row in cursor.execute( # pyright: ignore[reportAny]
"SELECT name FROM sqlite_master WHERE type = 'table'"
).fetchall()
if row[0][:8] == "raw_mca_"
]
}
connection.close()
def _generate_sqlite3_connection(self) -> sqlite3.Connection:
return sqlite3.connect(self.engine.url.__to_string__().split("///")[1])
# TODO: Implement docstrings from 'spec_mca_...' tables
def _hardcode_table_dataclasses(self):
text: str = (
"# This file is pre-generated for type-hinting while writing MCA file queries. \n"
+ "# Any changes made manually will likely be overwritten. \n"
+ "# It should not need to be generated more than once. \n"
+ "# If the RSP's timetable specification changes, then this will need to be updated. \n"
+ "# Result of mca_queries.py's Timetable._hardcode_table_dataclasses. \n"
+ "\n"
+ "# Imports \n"
+ "from dataclasses import dataclass, field \n"
+ "from typing import Any \n"
+ "from sqlalchemy import Column, MetaData, Table, String, Integer \n"
+ "\n"
+ "# Init. \n"
+ "metadata = MetaData() \n"
+ "\n"
+ "# Classes \n"
+ "@dataclass \n"
+ "class BaseRecord: \n"
+ "\tall: Table \n\n"
+ "\t@property \n"
+ "\tdef schedule_number(self) -> Column[Integer]: ... \n\n"
+ "\t@property \n"
+ "\tdef line_number(self) -> Column[Integer]: ... \n\n"
)
for name in self.tables:
columns = [column.name for column in self.tables[name].columns]
text += (
f"_{(rr := name.split('_')[-1].upper())}_columns = Table( \n"
+ f"\t'{name}', \n"
+ "\tmetadata, \n"
+ "".join([f"\tColumn('{column}'), \n" for column in columns])
+ ") \n\n"
)
text += (
"@dataclass \n"
+ f"class _{rr}_base(BaseRecord): \n"
+ "\tall: Table \n\n"
+ "".join(
[
(
"\t@property \n"
+ f"\tdef {column}(self) -> "
+ f"Column[{'String' if column not in ['line_number', 'schedule_number'] else 'Integer'}]: \n"
+ f"\t\treturn self.all.c.{column} \n\n"
)
for column in columns
]
)
)
text += f"{rr} = _{rr}_base(_{rr}_columns) \n\n"
path = Path(__file__).parent / "mca_record_types.py"
with open(path, "w") as wf:
_ = wf.write(text.replace("\t", " "))
@classmethod
def from_db_path(cls, db_path: Path | None = None):
db_path = validate_db_path(db_path)
return cls(engine=create_engine(url=f"sqlite:///{db_path.as_posix()}"))
@classmethod
def default(cls):
return cls.from_db_path()
def execute(self, query: Select[Any]) -> pd.DataFrame: # pyright: ignore[reportExplicitAny]
with self.engine.connect() as connection:
return pd.read_sql(query, connection)
def _fetch_records_of_schedules(
self,
record_type: BaseRecord,
*schedule_numbers: int,
) -> pd.DataFrame:
return self.execute(
select(record_type.all).where(
record_type.schedule_number.in_(schedule_numbers)
)
)
def fetch_schedules(self, *schedule_numbers: int) -> dict[int, Schedule]:
bs = self._fetch_records_of_schedules(BS, *schedule_numbers)
bx = self._fetch_records_of_schedules(BX, *schedule_numbers)
lo = self._fetch_records_of_schedules(LO, *schedule_numbers)
li = self._fetch_records_of_schedules(LI, *schedule_numbers)
lt = self._fetch_records_of_schedules(LT, *schedule_numbers)
cr = self._fetch_records_of_schedules(CR, *schedule_numbers)
return {
sn: Schedule(
sn=sn,
bs=bs[bs.schedule_number == sn].iloc[0],
bx=bx[bx.schedule_number == sn].iloc[0],
loit=pd.concat(
[
lo[lo.schedule_number == sn],
li[li.schedule_number == sn],
lt[lt.schedule_number == sn],
]
).reset_index(drop=True),
cr=cr[cr.schedule_number == sn],
)
for sn in schedule_numbers
}
# Functions
def services_date_and_tiploc(
date: SixDate,
tiploc: str,
tt: Timetable | None = None,
):
tt = tt if tt is not None else Timetable()
on_date: Iterable[int] = tt.execute(
select(BS.schedule_number).where(
(BS.date_runs_from <= date)
& (BS.date_runs_to >= date)
& (BS.days_run.like(date.weekday_like))
)
).schedule_number.values
origin: Iterable[int] = tt.execute(
select(LO.schedule_number).where(LO.location == f"{tiploc:<8}")
).schedule_number.values
en_route: Iterable[int] = tt.execute(
select(LI.schedule_number).where(
(LI.location == f"{tiploc:<8}")
& (LI.scheduled_departure_time.not_like(" %"))
)
).schedule_number.values
destination: Iterable[int] = tt.execute(
select(LT.schedule_number).where(LT.location == f"{tiploc:<8}")
).schedule_number.values
sns = np.unique([*origin, *en_route, *destination])
sns = np.intersect1d(sns, on_date) # pyright: ignore[reportCallIssue, reportUnknownVariableType, reportArgumentType]
sns = [int(sn) for sn in sns] # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType]
return [
Service(
date=date,
**schedule.__dict__, # pyright: ignore[reportAny]
)
for _, schedule in tt.fetch_schedules(*sns).items()
]
# Script
def main():
print(s := services_date_and_tiploc(SixDate("260525"), "YORK"))
return len(s)
if __name__ == "__main__":
print(main())
File diff suppressed because it is too large Load Diff
-167
View File
@@ -9,17 +9,13 @@ Aimed primarily towards producing a reduced sqlite database.
# pyright: reportUnknownLambdaType=false
# Imports
import os
import sqlite3
from itertools import pairwise
from pathlib import Path
from zipfile import ZipFile
import numpy as np
import pandas as pd
from pypdf import PageObject, PdfReader
from national_rail_timetable.nr_requests import fetch_nr_timetable_files
# Init.
SPECIFICATION_TABLE_LOCATIONS = {
@@ -37,7 +33,6 @@ SPECIFICATION_TABLE_LOCATIONS = {
"MCA_ZZ": (23, 1),
}
DEFAULT_RAW_SPEC_DATA_DIR = Path(__file__).parents[2] / "data/specification_tables"
DEFAULT_DB_PATH = Path.home() / ".cache/nr_data/timetable.db"
# Functions
@@ -151,165 +146,3 @@ def read_specification_table_raws(
continue
tables[path.name[:-4]] = pd.read_csv(path)
return tables
def validate_db_path(db_path: Path | None = None):
db_path = (
db_path
if db_path is not None
else Path(os.environ.get("NR_DATADIR", DEFAULT_DB_PATH))
)
db_path.parent.mkdir(exist_ok=True, parents=True)
return db_path
def create_mca_specification_dbtables(
tables: dict[str, pd.DataFrame],
db_path: Path | None = None,
):
db_path = validate_db_path(db_path)
connection = sqlite3.connect(db_path)
cursor = connection.cursor()
for name, df in tables.items():
if (_n := name.split("_"))[0] != "MCA" or len(_n) != 2:
continue
df["Start Index"] = df["Position"].apply(lambda s: int(s.split("-")[0]) - 1)
df["End Index"] = df["Position"].apply(lambda s: int(s.split("-")[-1]))
_ = cursor.execute(f"DROP TABLE IF EXISTS spec_{name.lower()}")
_ = cursor.execute(
f"""
CREATE TABLE spec_{name.lower()}
({", ".join([col.lower().replace(" ", "_") for col in df.columns])})
""",
)
_ = cursor.executemany(
f"""
INSERT INTO spec_{name.lower()}
VALUES({", ".join(["?" for _ in df.columns])})
""",
[tuple(row.values) for _, row in df.iterrows()],
)
connection.commit()
connection.close()
return db_path
# TODO: There is no need for this to take minutes, investigate speed ups.
def create_mca_raw_dbtables(
zipfile: ZipFile | None = None,
db_path: Path | None = None,
allow_fetch: bool = True,
print_progress: bool = True,
) -> dict[str, str]:
db_path = validate_db_path(db_path)
if zipfile is None:
if allow_fetch:
zipfile = fetch_nr_timetable_files()
else:
raise RuntimeError(
"There was no zipfile provided and allow_fetch is set to False. "
+ "Please either allow automatic fetching, or supply the zipfile argument. "
+ "The package's fetching function is fetch_nr_timetable_files. "
)
connection = sqlite3.connect(db_path)
cursor = connection.cursor()
tables = [
row[0]
for row in cursor.execute(
"SELECT name FROM sqlite_master WHERE type = 'table'"
).fetchall()
if row[0][:9] == "spec_mca_"
]
if len(tables) == 0:
raise FileNotFoundError(
"No spec_mca_... tables found in given database. "
+ "Please ensure create_mca_specification_dbtables has been run successfully. "
)
mappings = {}
all_start_indexes = {}
all_end_indexes = {}
for name in tables:
mappings[name] = "raw_" + name[5:]
spec = pd.DataFrame(
cursor.execute(f"SELECT * FROM {name}").fetchall(),
columns=[d[0] for d in cursor.description],
)
all_start_indexes[mappings[name]] = spec.start_index.values
all_end_indexes[mappings[name]] = spec.end_index.values
new_columns = [
col.split("/")[0].lower().replace(" ", "_").replace("-", "_")
for col in spec.field_description
] + ["line_number", "schedule_number"]
_ = cursor.execute(f"DROP TABLE IF EXISTS {mappings[name].lower()}")
_ = cursor.execute(
f"""
CREATE TABLE {mappings[name]}
({(", ".join(new_columns))})
""",
)
connection.commit()
zipinfo = [
zipinfo
for name, zipinfo in zipfile.NameToInfo.items()
if name.split(".")[-1] == "MCA"
][0]
file = zipfile.open(zipinfo)
schedule_number, line_number = -1, -1
while (line := file.readline().decode()) != "":
line_number += 1
record_type = line[:2]
schedule_number += int(record_type == "BS")
target_table = f"raw_mca_{record_type.lower()}"
start_indexes = all_start_indexes.get(target_table)
end_indexes = all_end_indexes.get(target_table)
if start_indexes is None or end_indexes is None:
continue
values = ", ".join(
[
"'" + line[lb:ub].replace("'", "") + "'"
for lb, ub in zip(start_indexes, end_indexes, strict=True)
]
+ [f"{line_number}", f"{schedule_number}"]
)
_ = cursor.execute(f"INSERT INTO {target_table} VALUES({values})")
if line_number % 3737 == 0 and print_progress:
print(f" {line_number:9,} {line[:-1]}", end="\r")
if print_progress:
print()
connection.commit()
connection.close()
return mappings
# Script
def main(
skip_pdf: bool = False,
pdf_spec_path: Path | None = None,
raw_spec_dir: Path | None = None,
):
if not skip_pdf:
try:
tables = extract_specification_document_tables(pdf_spec_path)
_ = store_specification_table_raws(tables, raw_spec_dir)
except FileNotFoundError:
pass
try:
tables = read_specification_table_raws(raw_spec_dir)
except FileNotFoundError:
raise FileNotFoundError(
"The tables generated from the RSP's specification were not found. "
+ "This means neither the cached version nor the original .pdf is available. "
+ "Try suppling either to their default locations, or supplying custom directories. "
+ "Manual fix: extract_specification_document_tables then store_specification_table_raws. "
)
_ = create_mca_specification_dbtables(tables)
return create_mca_raw_dbtables()
if __name__ == "__main__":
_ = main()