Removal of sqlite/sqlalchemy based approach - it is too slow to combine the results, even with in-memory database loading.
This commit is contained in:
+2
-1
@@ -11,7 +11,8 @@ dependencies = [
|
|||||||
"pypdf (>=6.12.0,<7.0.0)",
|
"pypdf (>=6.12.0,<7.0.0)",
|
||||||
"pandas (>=3.0.3,<4.0.0)",
|
"pandas (>=3.0.3,<4.0.0)",
|
||||||
"pandas-stubs (>=3.0.0.260204,<4.0.0.0)",
|
"pandas-stubs (>=3.0.0.260204,<4.0.0.0)",
|
||||||
"sqlalchemy (>=2.0.49,<3.0.0)"
|
"sqlalchemy (>=2.0.49,<3.0.0)",
|
||||||
|
"py-spy (>=0.4.2,<0.5.0)"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,279 +0,0 @@
|
|||||||
"""
|
|
||||||
Queries for the 'raw_mca_...' tables generated in parsing.py.
|
|
||||||
Thus far, attempts at few assumptions for record types have been made.
|
|
||||||
These queries will outright expect certain properties of the databases generated.
|
|
||||||
If they suddenly stop working - it should be due to an RSP specification change.
|
|
||||||
Therefore, the error handling will be less graceful as this is not predictable.
|
|
||||||
It seems unlikely that the CIF format will be modified any time soon.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Imports
|
|
||||||
from collections.abc import Iterable
|
|
||||||
import sqlite3
|
|
||||||
from datetime import datetime
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
from sqlalchemy import (
|
|
||||||
Column,
|
|
||||||
Engine,
|
|
||||||
MetaData,
|
|
||||||
Select,
|
|
||||||
Table,
|
|
||||||
create_engine,
|
|
||||||
select,
|
|
||||||
)
|
|
||||||
|
|
||||||
from national_rail_timetable.mca_record_types import (
|
|
||||||
BS,
|
|
||||||
BX,
|
|
||||||
CR,
|
|
||||||
LI,
|
|
||||||
LO,
|
|
||||||
LT,
|
|
||||||
BaseRecord,
|
|
||||||
)
|
|
||||||
from national_rail_timetable.parsing import validate_db_path
|
|
||||||
|
|
||||||
|
|
||||||
# Classes
|
|
||||||
class SixDate(str):
|
|
||||||
def __post_init__(self):
|
|
||||||
assert len(self) == 6
|
|
||||||
assert self.isnumeric()
|
|
||||||
|
|
||||||
def as_unix(self):
|
|
||||||
return f"20{self[:2]}-{self[2:4]}-{self[-2:]}"
|
|
||||||
|
|
||||||
def as_numpy(self):
|
|
||||||
return np.datetime64(self.as_unix())
|
|
||||||
|
|
||||||
def as_datetime(self):
|
|
||||||
return datetime.fromisoformat(self.as_unix())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_unix(cls, unix_str: str):
|
|
||||||
return cls(unix_str[2:].replace("-", "").replace("/", ""))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_numpy(cls, np_dt: np.datetime64):
|
|
||||||
return cls.from_unix(str(np_dt).split("T")[0])
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_datetime(cls, dt: datetime):
|
|
||||||
return cls.from_unix(str(dt).split(" ")[0])
|
|
||||||
|
|
||||||
@property
|
|
||||||
def weekday(self) -> int:
|
|
||||||
return self.as_datetime().weekday()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def weekday_like(self) -> str:
|
|
||||||
return "_" * self.weekday + "1%"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Schedule:
|
|
||||||
sn: int
|
|
||||||
bs: pd.Series
|
|
||||||
bx: pd.Series
|
|
||||||
loit: pd.DataFrame
|
|
||||||
cr: pd.DataFrame
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Service(Schedule):
|
|
||||||
date: SixDate
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Timetable:
|
|
||||||
engine: Engine = field(
|
|
||||||
default_factory=lambda: create_engine(
|
|
||||||
url=f"sqlite:///{validate_db_path().as_posix()}"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
metadata: MetaData = field(default_factory=MetaData)
|
|
||||||
tables: dict[str, Table] = field(default_factory=dict)
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
self._populate_self_tables()
|
|
||||||
|
|
||||||
def _populate_self_tables(self):
|
|
||||||
cursor = (connection := self._generate_sqlite3_connection()).cursor()
|
|
||||||
self.tables |= {
|
|
||||||
name: Table(
|
|
||||||
name, # pyright: ignore[reportAny]
|
|
||||||
self.metadata,
|
|
||||||
*[
|
|
||||||
Column(d[0])
|
|
||||||
for d in cursor.execute(f"SELECT * FROM {name} LIMIT 1").description
|
|
||||||
],
|
|
||||||
)
|
|
||||||
for name in [ # pyright: ignore[reportAny]
|
|
||||||
row[0]
|
|
||||||
for row in cursor.execute( # pyright: ignore[reportAny]
|
|
||||||
"SELECT name FROM sqlite_master WHERE type = 'table'"
|
|
||||||
).fetchall()
|
|
||||||
if row[0][:8] == "raw_mca_"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
connection.close()
|
|
||||||
|
|
||||||
def _generate_sqlite3_connection(self) -> sqlite3.Connection:
|
|
||||||
return sqlite3.connect(self.engine.url.__to_string__().split("///")[1])
|
|
||||||
|
|
||||||
# TODO: Implement docstrings from 'spec_mca_...' tables
|
|
||||||
def _hardcode_table_dataclasses(self):
|
|
||||||
text: str = (
|
|
||||||
"# This file is pre-generated for type-hinting while writing MCA file queries. \n"
|
|
||||||
+ "# Any changes made manually will likely be overwritten. \n"
|
|
||||||
+ "# It should not need to be generated more than once. \n"
|
|
||||||
+ "# If the RSP's timetable specification changes, then this will need to be updated. \n"
|
|
||||||
+ "# Result of mca_queries.py's Timetable._hardcode_table_dataclasses. \n"
|
|
||||||
+ "\n"
|
|
||||||
+ "# Imports \n"
|
|
||||||
+ "from dataclasses import dataclass, field \n"
|
|
||||||
+ "from typing import Any \n"
|
|
||||||
+ "from sqlalchemy import Column, MetaData, Table, String, Integer \n"
|
|
||||||
+ "\n"
|
|
||||||
+ "# Init. \n"
|
|
||||||
+ "metadata = MetaData() \n"
|
|
||||||
+ "\n"
|
|
||||||
+ "# Classes \n"
|
|
||||||
+ "@dataclass \n"
|
|
||||||
+ "class BaseRecord: \n"
|
|
||||||
+ "\tall: Table \n\n"
|
|
||||||
+ "\t@property \n"
|
|
||||||
+ "\tdef schedule_number(self) -> Column[Integer]: ... \n\n"
|
|
||||||
+ "\t@property \n"
|
|
||||||
+ "\tdef line_number(self) -> Column[Integer]: ... \n\n"
|
|
||||||
)
|
|
||||||
for name in self.tables:
|
|
||||||
columns = [column.name for column in self.tables[name].columns]
|
|
||||||
text += (
|
|
||||||
f"_{(rr := name.split('_')[-1].upper())}_columns = Table( \n"
|
|
||||||
+ f"\t'{name}', \n"
|
|
||||||
+ "\tmetadata, \n"
|
|
||||||
+ "".join([f"\tColumn('{column}'), \n" for column in columns])
|
|
||||||
+ ") \n\n"
|
|
||||||
)
|
|
||||||
text += (
|
|
||||||
"@dataclass \n"
|
|
||||||
+ f"class _{rr}_base(BaseRecord): \n"
|
|
||||||
+ "\tall: Table \n\n"
|
|
||||||
+ "".join(
|
|
||||||
[
|
|
||||||
(
|
|
||||||
"\t@property \n"
|
|
||||||
+ f"\tdef {column}(self) -> "
|
|
||||||
+ f"Column[{'String' if column not in ['line_number', 'schedule_number'] else 'Integer'}]: \n"
|
|
||||||
+ f"\t\treturn self.all.c.{column} \n\n"
|
|
||||||
)
|
|
||||||
for column in columns
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
text += f"{rr} = _{rr}_base(_{rr}_columns) \n\n"
|
|
||||||
path = Path(__file__).parent / "mca_record_types.py"
|
|
||||||
with open(path, "w") as wf:
|
|
||||||
_ = wf.write(text.replace("\t", " "))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_db_path(cls, db_path: Path | None = None):
|
|
||||||
db_path = validate_db_path(db_path)
|
|
||||||
return cls(engine=create_engine(url=f"sqlite:///{db_path.as_posix()}"))
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def default(cls):
|
|
||||||
return cls.from_db_path()
|
|
||||||
|
|
||||||
def execute(self, query: Select[Any]) -> pd.DataFrame: # pyright: ignore[reportExplicitAny]
|
|
||||||
with self.engine.connect() as connection:
|
|
||||||
return pd.read_sql(query, connection)
|
|
||||||
|
|
||||||
def _fetch_records_of_schedules(
|
|
||||||
self,
|
|
||||||
record_type: BaseRecord,
|
|
||||||
*schedule_numbers: int,
|
|
||||||
) -> pd.DataFrame:
|
|
||||||
return self.execute(
|
|
||||||
select(record_type.all).where(
|
|
||||||
record_type.schedule_number.in_(schedule_numbers)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def fetch_schedules(self, *schedule_numbers: int) -> dict[int, Schedule]:
|
|
||||||
bs = self._fetch_records_of_schedules(BS, *schedule_numbers)
|
|
||||||
bx = self._fetch_records_of_schedules(BX, *schedule_numbers)
|
|
||||||
lo = self._fetch_records_of_schedules(LO, *schedule_numbers)
|
|
||||||
li = self._fetch_records_of_schedules(LI, *schedule_numbers)
|
|
||||||
lt = self._fetch_records_of_schedules(LT, *schedule_numbers)
|
|
||||||
cr = self._fetch_records_of_schedules(CR, *schedule_numbers)
|
|
||||||
return {
|
|
||||||
sn: Schedule(
|
|
||||||
sn=sn,
|
|
||||||
bs=bs[bs.schedule_number == sn].iloc[0],
|
|
||||||
bx=bx[bx.schedule_number == sn].iloc[0],
|
|
||||||
loit=pd.concat(
|
|
||||||
[
|
|
||||||
lo[lo.schedule_number == sn],
|
|
||||||
li[li.schedule_number == sn],
|
|
||||||
lt[lt.schedule_number == sn],
|
|
||||||
]
|
|
||||||
).reset_index(drop=True),
|
|
||||||
cr=cr[cr.schedule_number == sn],
|
|
||||||
)
|
|
||||||
for sn in schedule_numbers
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# Functions
|
|
||||||
def services_date_and_tiploc(
|
|
||||||
date: SixDate,
|
|
||||||
tiploc: str,
|
|
||||||
tt: Timetable | None = None,
|
|
||||||
):
|
|
||||||
tt = tt if tt is not None else Timetable()
|
|
||||||
on_date: Iterable[int] = tt.execute(
|
|
||||||
select(BS.schedule_number).where(
|
|
||||||
(BS.date_runs_from <= date)
|
|
||||||
& (BS.date_runs_to >= date)
|
|
||||||
& (BS.days_run.like(date.weekday_like))
|
|
||||||
)
|
|
||||||
).schedule_number.values
|
|
||||||
origin: Iterable[int] = tt.execute(
|
|
||||||
select(LO.schedule_number).where(LO.location == f"{tiploc:<8}")
|
|
||||||
).schedule_number.values
|
|
||||||
en_route: Iterable[int] = tt.execute(
|
|
||||||
select(LI.schedule_number).where(
|
|
||||||
(LI.location == f"{tiploc:<8}")
|
|
||||||
& (LI.scheduled_departure_time.not_like(" %"))
|
|
||||||
)
|
|
||||||
).schedule_number.values
|
|
||||||
destination: Iterable[int] = tt.execute(
|
|
||||||
select(LT.schedule_number).where(LT.location == f"{tiploc:<8}")
|
|
||||||
).schedule_number.values
|
|
||||||
sns = np.unique([*origin, *en_route, *destination])
|
|
||||||
sns = np.intersect1d(sns, on_date) # pyright: ignore[reportCallIssue, reportUnknownVariableType, reportArgumentType]
|
|
||||||
sns = [int(sn) for sn in sns] # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType]
|
|
||||||
return [
|
|
||||||
Service(
|
|
||||||
date=date,
|
|
||||||
**schedule.__dict__, # pyright: ignore[reportAny]
|
|
||||||
)
|
|
||||||
for _, schedule in tt.fetch_schedules(*sns).items()
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# Script
|
|
||||||
def main():
|
|
||||||
print(s := services_date_and_tiploc(SixDate("260525"), "YORK"))
|
|
||||||
return len(s)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print(main())
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -9,17 +9,13 @@ Aimed primarily towards producing a reduced sqlite database.
|
|||||||
# pyright: reportUnknownLambdaType=false
|
# pyright: reportUnknownLambdaType=false
|
||||||
|
|
||||||
# Imports
|
# Imports
|
||||||
import os
|
|
||||||
import sqlite3
|
|
||||||
from itertools import pairwise
|
from itertools import pairwise
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from zipfile import ZipFile
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from pypdf import PageObject, PdfReader
|
from pypdf import PageObject, PdfReader
|
||||||
|
|
||||||
from national_rail_timetable.nr_requests import fetch_nr_timetable_files
|
|
||||||
|
|
||||||
# Init.
|
# Init.
|
||||||
SPECIFICATION_TABLE_LOCATIONS = {
|
SPECIFICATION_TABLE_LOCATIONS = {
|
||||||
@@ -37,7 +33,6 @@ SPECIFICATION_TABLE_LOCATIONS = {
|
|||||||
"MCA_ZZ": (23, 1),
|
"MCA_ZZ": (23, 1),
|
||||||
}
|
}
|
||||||
DEFAULT_RAW_SPEC_DATA_DIR = Path(__file__).parents[2] / "data/specification_tables"
|
DEFAULT_RAW_SPEC_DATA_DIR = Path(__file__).parents[2] / "data/specification_tables"
|
||||||
DEFAULT_DB_PATH = Path.home() / ".cache/nr_data/timetable.db"
|
|
||||||
|
|
||||||
|
|
||||||
# Functions
|
# Functions
|
||||||
@@ -151,165 +146,3 @@ def read_specification_table_raws(
|
|||||||
continue
|
continue
|
||||||
tables[path.name[:-4]] = pd.read_csv(path)
|
tables[path.name[:-4]] = pd.read_csv(path)
|
||||||
return tables
|
return tables
|
||||||
|
|
||||||
|
|
||||||
def validate_db_path(db_path: Path | None = None):
|
|
||||||
db_path = (
|
|
||||||
db_path
|
|
||||||
if db_path is not None
|
|
||||||
else Path(os.environ.get("NR_DATADIR", DEFAULT_DB_PATH))
|
|
||||||
)
|
|
||||||
db_path.parent.mkdir(exist_ok=True, parents=True)
|
|
||||||
return db_path
|
|
||||||
|
|
||||||
|
|
||||||
def create_mca_specification_dbtables(
|
|
||||||
tables: dict[str, pd.DataFrame],
|
|
||||||
db_path: Path | None = None,
|
|
||||||
):
|
|
||||||
db_path = validate_db_path(db_path)
|
|
||||||
connection = sqlite3.connect(db_path)
|
|
||||||
cursor = connection.cursor()
|
|
||||||
for name, df in tables.items():
|
|
||||||
if (_n := name.split("_"))[0] != "MCA" or len(_n) != 2:
|
|
||||||
continue
|
|
||||||
df["Start Index"] = df["Position"].apply(lambda s: int(s.split("-")[0]) - 1)
|
|
||||||
df["End Index"] = df["Position"].apply(lambda s: int(s.split("-")[-1]))
|
|
||||||
|
|
||||||
_ = cursor.execute(f"DROP TABLE IF EXISTS spec_{name.lower()}")
|
|
||||||
_ = cursor.execute(
|
|
||||||
f"""
|
|
||||||
CREATE TABLE spec_{name.lower()}
|
|
||||||
({", ".join([col.lower().replace(" ", "_") for col in df.columns])})
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
_ = cursor.executemany(
|
|
||||||
f"""
|
|
||||||
INSERT INTO spec_{name.lower()}
|
|
||||||
VALUES({", ".join(["?" for _ in df.columns])})
|
|
||||||
""",
|
|
||||||
[tuple(row.values) for _, row in df.iterrows()],
|
|
||||||
)
|
|
||||||
connection.commit()
|
|
||||||
connection.close()
|
|
||||||
return db_path
|
|
||||||
|
|
||||||
|
|
||||||
# TODO: There is no need for this to take minutes, investigate speed ups.
|
|
||||||
def create_mca_raw_dbtables(
|
|
||||||
zipfile: ZipFile | None = None,
|
|
||||||
db_path: Path | None = None,
|
|
||||||
allow_fetch: bool = True,
|
|
||||||
print_progress: bool = True,
|
|
||||||
) -> dict[str, str]:
|
|
||||||
db_path = validate_db_path(db_path)
|
|
||||||
if zipfile is None:
|
|
||||||
if allow_fetch:
|
|
||||||
zipfile = fetch_nr_timetable_files()
|
|
||||||
else:
|
|
||||||
raise RuntimeError(
|
|
||||||
"There was no zipfile provided and allow_fetch is set to False. "
|
|
||||||
+ "Please either allow automatic fetching, or supply the zipfile argument. "
|
|
||||||
+ "The package's fetching function is fetch_nr_timetable_files. "
|
|
||||||
)
|
|
||||||
|
|
||||||
connection = sqlite3.connect(db_path)
|
|
||||||
cursor = connection.cursor()
|
|
||||||
tables = [
|
|
||||||
row[0]
|
|
||||||
for row in cursor.execute(
|
|
||||||
"SELECT name FROM sqlite_master WHERE type = 'table'"
|
|
||||||
).fetchall()
|
|
||||||
if row[0][:9] == "spec_mca_"
|
|
||||||
]
|
|
||||||
if len(tables) == 0:
|
|
||||||
raise FileNotFoundError(
|
|
||||||
"No spec_mca_... tables found in given database. "
|
|
||||||
+ "Please ensure create_mca_specification_dbtables has been run successfully. "
|
|
||||||
)
|
|
||||||
mappings = {}
|
|
||||||
all_start_indexes = {}
|
|
||||||
all_end_indexes = {}
|
|
||||||
for name in tables:
|
|
||||||
mappings[name] = "raw_" + name[5:]
|
|
||||||
spec = pd.DataFrame(
|
|
||||||
cursor.execute(f"SELECT * FROM {name}").fetchall(),
|
|
||||||
columns=[d[0] for d in cursor.description],
|
|
||||||
)
|
|
||||||
all_start_indexes[mappings[name]] = spec.start_index.values
|
|
||||||
all_end_indexes[mappings[name]] = spec.end_index.values
|
|
||||||
new_columns = [
|
|
||||||
col.split("/")[0].lower().replace(" ", "_").replace("-", "_")
|
|
||||||
for col in spec.field_description
|
|
||||||
] + ["line_number", "schedule_number"]
|
|
||||||
_ = cursor.execute(f"DROP TABLE IF EXISTS {mappings[name].lower()}")
|
|
||||||
_ = cursor.execute(
|
|
||||||
f"""
|
|
||||||
CREATE TABLE {mappings[name]}
|
|
||||||
({(", ".join(new_columns))})
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
connection.commit()
|
|
||||||
|
|
||||||
zipinfo = [
|
|
||||||
zipinfo
|
|
||||||
for name, zipinfo in zipfile.NameToInfo.items()
|
|
||||||
if name.split(".")[-1] == "MCA"
|
|
||||||
][0]
|
|
||||||
file = zipfile.open(zipinfo)
|
|
||||||
schedule_number, line_number = -1, -1
|
|
||||||
while (line := file.readline().decode()) != "":
|
|
||||||
line_number += 1
|
|
||||||
record_type = line[:2]
|
|
||||||
schedule_number += int(record_type == "BS")
|
|
||||||
target_table = f"raw_mca_{record_type.lower()}"
|
|
||||||
start_indexes = all_start_indexes.get(target_table)
|
|
||||||
end_indexes = all_end_indexes.get(target_table)
|
|
||||||
if start_indexes is None or end_indexes is None:
|
|
||||||
continue
|
|
||||||
values = ", ".join(
|
|
||||||
[
|
|
||||||
"'" + line[lb:ub].replace("'", "") + "'"
|
|
||||||
for lb, ub in zip(start_indexes, end_indexes, strict=True)
|
|
||||||
]
|
|
||||||
+ [f"{line_number}", f"{schedule_number}"]
|
|
||||||
)
|
|
||||||
_ = cursor.execute(f"INSERT INTO {target_table} VALUES({values})")
|
|
||||||
if line_number % 3737 == 0 and print_progress:
|
|
||||||
print(f" {line_number:9,} {line[:-1]}", end="\r")
|
|
||||||
if print_progress:
|
|
||||||
print()
|
|
||||||
connection.commit()
|
|
||||||
connection.close()
|
|
||||||
return mappings
|
|
||||||
|
|
||||||
|
|
||||||
# Script
|
|
||||||
def main(
|
|
||||||
skip_pdf: bool = False,
|
|
||||||
pdf_spec_path: Path | None = None,
|
|
||||||
raw_spec_dir: Path | None = None,
|
|
||||||
):
|
|
||||||
if not skip_pdf:
|
|
||||||
try:
|
|
||||||
tables = extract_specification_document_tables(pdf_spec_path)
|
|
||||||
_ = store_specification_table_raws(tables, raw_spec_dir)
|
|
||||||
except FileNotFoundError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
tables = read_specification_table_raws(raw_spec_dir)
|
|
||||||
except FileNotFoundError:
|
|
||||||
raise FileNotFoundError(
|
|
||||||
"The tables generated from the RSP's specification were not found. "
|
|
||||||
+ "This means neither the cached version nor the original .pdf is available. "
|
|
||||||
+ "Try suppling either to their default locations, or supplying custom directories. "
|
|
||||||
+ "Manual fix: extract_specification_document_tables then store_specification_table_raws. "
|
|
||||||
)
|
|
||||||
|
|
||||||
_ = create_mca_specification_dbtables(tables)
|
|
||||||
return create_mca_raw_dbtables()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
_ = main()
|
|
||||||
|
|||||||
Reference in New Issue
Block a user