Removal of sqlite/sqlalchemy based approach - it is too slow to combine the results, even with in-memory database loading.
This commit is contained in:
+2
-1
@@ -11,7 +11,8 @@ dependencies = [
|
||||
"pypdf (>=6.12.0,<7.0.0)",
|
||||
"pandas (>=3.0.3,<4.0.0)",
|
||||
"pandas-stubs (>=3.0.0.260204,<4.0.0.0)",
|
||||
"sqlalchemy (>=2.0.49,<3.0.0)"
|
||||
"sqlalchemy (>=2.0.49,<3.0.0)",
|
||||
"py-spy (>=0.4.2,<0.5.0)"
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,279 +0,0 @@
|
||||
"""
|
||||
Queries for the 'raw_mca_...' tables generated in parsing.py.
|
||||
Thus far, attempts at few assumptions for record types have been made.
|
||||
These queries will outright expect certain properties of the databases generated.
|
||||
If they suddenly stop working - it should be due to an RSP specification change.
|
||||
Therefore, the error handling will be less graceful as this is not predictable.
|
||||
It seems unlikely that the CIF format will be modified any time soon.
|
||||
"""
|
||||
|
||||
# Imports
|
||||
from collections.abc import Iterable
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Engine,
|
||||
MetaData,
|
||||
Select,
|
||||
Table,
|
||||
create_engine,
|
||||
select,
|
||||
)
|
||||
|
||||
from national_rail_timetable.mca_record_types import (
|
||||
BS,
|
||||
BX,
|
||||
CR,
|
||||
LI,
|
||||
LO,
|
||||
LT,
|
||||
BaseRecord,
|
||||
)
|
||||
from national_rail_timetable.parsing import validate_db_path
|
||||
|
||||
|
||||
# Classes
|
||||
class SixDate(str):
|
||||
def __post_init__(self):
|
||||
assert len(self) == 6
|
||||
assert self.isnumeric()
|
||||
|
||||
def as_unix(self):
|
||||
return f"20{self[:2]}-{self[2:4]}-{self[-2:]}"
|
||||
|
||||
def as_numpy(self):
|
||||
return np.datetime64(self.as_unix())
|
||||
|
||||
def as_datetime(self):
|
||||
return datetime.fromisoformat(self.as_unix())
|
||||
|
||||
@classmethod
|
||||
def from_unix(cls, unix_str: str):
|
||||
return cls(unix_str[2:].replace("-", "").replace("/", ""))
|
||||
|
||||
@classmethod
|
||||
def from_numpy(cls, np_dt: np.datetime64):
|
||||
return cls.from_unix(str(np_dt).split("T")[0])
|
||||
|
||||
@classmethod
|
||||
def from_datetime(cls, dt: datetime):
|
||||
return cls.from_unix(str(dt).split(" ")[0])
|
||||
|
||||
@property
|
||||
def weekday(self) -> int:
|
||||
return self.as_datetime().weekday()
|
||||
|
||||
@property
|
||||
def weekday_like(self) -> str:
|
||||
return "_" * self.weekday + "1%"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Schedule:
|
||||
sn: int
|
||||
bs: pd.Series
|
||||
bx: pd.Series
|
||||
loit: pd.DataFrame
|
||||
cr: pd.DataFrame
|
||||
|
||||
|
||||
@dataclass
|
||||
class Service(Schedule):
|
||||
date: SixDate
|
||||
|
||||
|
||||
@dataclass
|
||||
class Timetable:
|
||||
engine: Engine = field(
|
||||
default_factory=lambda: create_engine(
|
||||
url=f"sqlite:///{validate_db_path().as_posix()}"
|
||||
)
|
||||
)
|
||||
metadata: MetaData = field(default_factory=MetaData)
|
||||
tables: dict[str, Table] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
self._populate_self_tables()
|
||||
|
||||
def _populate_self_tables(self):
|
||||
cursor = (connection := self._generate_sqlite3_connection()).cursor()
|
||||
self.tables |= {
|
||||
name: Table(
|
||||
name, # pyright: ignore[reportAny]
|
||||
self.metadata,
|
||||
*[
|
||||
Column(d[0])
|
||||
for d in cursor.execute(f"SELECT * FROM {name} LIMIT 1").description
|
||||
],
|
||||
)
|
||||
for name in [ # pyright: ignore[reportAny]
|
||||
row[0]
|
||||
for row in cursor.execute( # pyright: ignore[reportAny]
|
||||
"SELECT name FROM sqlite_master WHERE type = 'table'"
|
||||
).fetchall()
|
||||
if row[0][:8] == "raw_mca_"
|
||||
]
|
||||
}
|
||||
connection.close()
|
||||
|
||||
def _generate_sqlite3_connection(self) -> sqlite3.Connection:
|
||||
return sqlite3.connect(self.engine.url.__to_string__().split("///")[1])
|
||||
|
||||
# TODO: Implement docstrings from 'spec_mca_...' tables
|
||||
def _hardcode_table_dataclasses(self):
|
||||
text: str = (
|
||||
"# This file is pre-generated for type-hinting while writing MCA file queries. \n"
|
||||
+ "# Any changes made manually will likely be overwritten. \n"
|
||||
+ "# It should not need to be generated more than once. \n"
|
||||
+ "# If the RSP's timetable specification changes, then this will need to be updated. \n"
|
||||
+ "# Result of mca_queries.py's Timetable._hardcode_table_dataclasses. \n"
|
||||
+ "\n"
|
||||
+ "# Imports \n"
|
||||
+ "from dataclasses import dataclass, field \n"
|
||||
+ "from typing import Any \n"
|
||||
+ "from sqlalchemy import Column, MetaData, Table, String, Integer \n"
|
||||
+ "\n"
|
||||
+ "# Init. \n"
|
||||
+ "metadata = MetaData() \n"
|
||||
+ "\n"
|
||||
+ "# Classes \n"
|
||||
+ "@dataclass \n"
|
||||
+ "class BaseRecord: \n"
|
||||
+ "\tall: Table \n\n"
|
||||
+ "\t@property \n"
|
||||
+ "\tdef schedule_number(self) -> Column[Integer]: ... \n\n"
|
||||
+ "\t@property \n"
|
||||
+ "\tdef line_number(self) -> Column[Integer]: ... \n\n"
|
||||
)
|
||||
for name in self.tables:
|
||||
columns = [column.name for column in self.tables[name].columns]
|
||||
text += (
|
||||
f"_{(rr := name.split('_')[-1].upper())}_columns = Table( \n"
|
||||
+ f"\t'{name}', \n"
|
||||
+ "\tmetadata, \n"
|
||||
+ "".join([f"\tColumn('{column}'), \n" for column in columns])
|
||||
+ ") \n\n"
|
||||
)
|
||||
text += (
|
||||
"@dataclass \n"
|
||||
+ f"class _{rr}_base(BaseRecord): \n"
|
||||
+ "\tall: Table \n\n"
|
||||
+ "".join(
|
||||
[
|
||||
(
|
||||
"\t@property \n"
|
||||
+ f"\tdef {column}(self) -> "
|
||||
+ f"Column[{'String' if column not in ['line_number', 'schedule_number'] else 'Integer'}]: \n"
|
||||
+ f"\t\treturn self.all.c.{column} \n\n"
|
||||
)
|
||||
for column in columns
|
||||
]
|
||||
)
|
||||
)
|
||||
text += f"{rr} = _{rr}_base(_{rr}_columns) \n\n"
|
||||
path = Path(__file__).parent / "mca_record_types.py"
|
||||
with open(path, "w") as wf:
|
||||
_ = wf.write(text.replace("\t", " "))
|
||||
|
||||
@classmethod
|
||||
def from_db_path(cls, db_path: Path | None = None):
|
||||
db_path = validate_db_path(db_path)
|
||||
return cls(engine=create_engine(url=f"sqlite:///{db_path.as_posix()}"))
|
||||
|
||||
@classmethod
|
||||
def default(cls):
|
||||
return cls.from_db_path()
|
||||
|
||||
def execute(self, query: Select[Any]) -> pd.DataFrame: # pyright: ignore[reportExplicitAny]
|
||||
with self.engine.connect() as connection:
|
||||
return pd.read_sql(query, connection)
|
||||
|
||||
def _fetch_records_of_schedules(
|
||||
self,
|
||||
record_type: BaseRecord,
|
||||
*schedule_numbers: int,
|
||||
) -> pd.DataFrame:
|
||||
return self.execute(
|
||||
select(record_type.all).where(
|
||||
record_type.schedule_number.in_(schedule_numbers)
|
||||
)
|
||||
)
|
||||
|
||||
def fetch_schedules(self, *schedule_numbers: int) -> dict[int, Schedule]:
|
||||
bs = self._fetch_records_of_schedules(BS, *schedule_numbers)
|
||||
bx = self._fetch_records_of_schedules(BX, *schedule_numbers)
|
||||
lo = self._fetch_records_of_schedules(LO, *schedule_numbers)
|
||||
li = self._fetch_records_of_schedules(LI, *schedule_numbers)
|
||||
lt = self._fetch_records_of_schedules(LT, *schedule_numbers)
|
||||
cr = self._fetch_records_of_schedules(CR, *schedule_numbers)
|
||||
return {
|
||||
sn: Schedule(
|
||||
sn=sn,
|
||||
bs=bs[bs.schedule_number == sn].iloc[0],
|
||||
bx=bx[bx.schedule_number == sn].iloc[0],
|
||||
loit=pd.concat(
|
||||
[
|
||||
lo[lo.schedule_number == sn],
|
||||
li[li.schedule_number == sn],
|
||||
lt[lt.schedule_number == sn],
|
||||
]
|
||||
).reset_index(drop=True),
|
||||
cr=cr[cr.schedule_number == sn],
|
||||
)
|
||||
for sn in schedule_numbers
|
||||
}
|
||||
|
||||
|
||||
# Functions
|
||||
def services_date_and_tiploc(
|
||||
date: SixDate,
|
||||
tiploc: str,
|
||||
tt: Timetable | None = None,
|
||||
):
|
||||
tt = tt if tt is not None else Timetable()
|
||||
on_date: Iterable[int] = tt.execute(
|
||||
select(BS.schedule_number).where(
|
||||
(BS.date_runs_from <= date)
|
||||
& (BS.date_runs_to >= date)
|
||||
& (BS.days_run.like(date.weekday_like))
|
||||
)
|
||||
).schedule_number.values
|
||||
origin: Iterable[int] = tt.execute(
|
||||
select(LO.schedule_number).where(LO.location == f"{tiploc:<8}")
|
||||
).schedule_number.values
|
||||
en_route: Iterable[int] = tt.execute(
|
||||
select(LI.schedule_number).where(
|
||||
(LI.location == f"{tiploc:<8}")
|
||||
& (LI.scheduled_departure_time.not_like(" %"))
|
||||
)
|
||||
).schedule_number.values
|
||||
destination: Iterable[int] = tt.execute(
|
||||
select(LT.schedule_number).where(LT.location == f"{tiploc:<8}")
|
||||
).schedule_number.values
|
||||
sns = np.unique([*origin, *en_route, *destination])
|
||||
sns = np.intersect1d(sns, on_date) # pyright: ignore[reportCallIssue, reportUnknownVariableType, reportArgumentType]
|
||||
sns = [int(sn) for sn in sns] # pyright: ignore[reportUnknownVariableType, reportUnknownArgumentType]
|
||||
return [
|
||||
Service(
|
||||
date=date,
|
||||
**schedule.__dict__, # pyright: ignore[reportAny]
|
||||
)
|
||||
for _, schedule in tt.fetch_schedules(*sns).items()
|
||||
]
|
||||
|
||||
|
||||
# Script
|
||||
def main():
|
||||
print(s := services_date_and_tiploc(SixDate("260525"), "YORK"))
|
||||
return len(s)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(main())
|
||||
File diff suppressed because it is too large
Load Diff
@@ -9,17 +9,13 @@ Aimed primarily towards producing a reduced sqlite database.
|
||||
# pyright: reportUnknownLambdaType=false
|
||||
|
||||
# Imports
|
||||
import os
|
||||
import sqlite3
|
||||
from itertools import pairwise
|
||||
from pathlib import Path
|
||||
from zipfile import ZipFile
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pypdf import PageObject, PdfReader
|
||||
|
||||
from national_rail_timetable.nr_requests import fetch_nr_timetable_files
|
||||
|
||||
# Init.
|
||||
SPECIFICATION_TABLE_LOCATIONS = {
|
||||
@@ -37,7 +33,6 @@ SPECIFICATION_TABLE_LOCATIONS = {
|
||||
"MCA_ZZ": (23, 1),
|
||||
}
|
||||
DEFAULT_RAW_SPEC_DATA_DIR = Path(__file__).parents[2] / "data/specification_tables"
|
||||
DEFAULT_DB_PATH = Path.home() / ".cache/nr_data/timetable.db"
|
||||
|
||||
|
||||
# Functions
|
||||
@@ -151,165 +146,3 @@ def read_specification_table_raws(
|
||||
continue
|
||||
tables[path.name[:-4]] = pd.read_csv(path)
|
||||
return tables
|
||||
|
||||
|
||||
def validate_db_path(db_path: Path | None = None):
|
||||
db_path = (
|
||||
db_path
|
||||
if db_path is not None
|
||||
else Path(os.environ.get("NR_DATADIR", DEFAULT_DB_PATH))
|
||||
)
|
||||
db_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
return db_path
|
||||
|
||||
|
||||
def create_mca_specification_dbtables(
|
||||
tables: dict[str, pd.DataFrame],
|
||||
db_path: Path | None = None,
|
||||
):
|
||||
db_path = validate_db_path(db_path)
|
||||
connection = sqlite3.connect(db_path)
|
||||
cursor = connection.cursor()
|
||||
for name, df in tables.items():
|
||||
if (_n := name.split("_"))[0] != "MCA" or len(_n) != 2:
|
||||
continue
|
||||
df["Start Index"] = df["Position"].apply(lambda s: int(s.split("-")[0]) - 1)
|
||||
df["End Index"] = df["Position"].apply(lambda s: int(s.split("-")[-1]))
|
||||
|
||||
_ = cursor.execute(f"DROP TABLE IF EXISTS spec_{name.lower()}")
|
||||
_ = cursor.execute(
|
||||
f"""
|
||||
CREATE TABLE spec_{name.lower()}
|
||||
({", ".join([col.lower().replace(" ", "_") for col in df.columns])})
|
||||
""",
|
||||
)
|
||||
_ = cursor.executemany(
|
||||
f"""
|
||||
INSERT INTO spec_{name.lower()}
|
||||
VALUES({", ".join(["?" for _ in df.columns])})
|
||||
""",
|
||||
[tuple(row.values) for _, row in df.iterrows()],
|
||||
)
|
||||
connection.commit()
|
||||
connection.close()
|
||||
return db_path
|
||||
|
||||
|
||||
# TODO: There is no need for this to take minutes, investigate speed ups.
|
||||
def create_mca_raw_dbtables(
|
||||
zipfile: ZipFile | None = None,
|
||||
db_path: Path | None = None,
|
||||
allow_fetch: bool = True,
|
||||
print_progress: bool = True,
|
||||
) -> dict[str, str]:
|
||||
db_path = validate_db_path(db_path)
|
||||
if zipfile is None:
|
||||
if allow_fetch:
|
||||
zipfile = fetch_nr_timetable_files()
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"There was no zipfile provided and allow_fetch is set to False. "
|
||||
+ "Please either allow automatic fetching, or supply the zipfile argument. "
|
||||
+ "The package's fetching function is fetch_nr_timetable_files. "
|
||||
)
|
||||
|
||||
connection = sqlite3.connect(db_path)
|
||||
cursor = connection.cursor()
|
||||
tables = [
|
||||
row[0]
|
||||
for row in cursor.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type = 'table'"
|
||||
).fetchall()
|
||||
if row[0][:9] == "spec_mca_"
|
||||
]
|
||||
if len(tables) == 0:
|
||||
raise FileNotFoundError(
|
||||
"No spec_mca_... tables found in given database. "
|
||||
+ "Please ensure create_mca_specification_dbtables has been run successfully. "
|
||||
)
|
||||
mappings = {}
|
||||
all_start_indexes = {}
|
||||
all_end_indexes = {}
|
||||
for name in tables:
|
||||
mappings[name] = "raw_" + name[5:]
|
||||
spec = pd.DataFrame(
|
||||
cursor.execute(f"SELECT * FROM {name}").fetchall(),
|
||||
columns=[d[0] for d in cursor.description],
|
||||
)
|
||||
all_start_indexes[mappings[name]] = spec.start_index.values
|
||||
all_end_indexes[mappings[name]] = spec.end_index.values
|
||||
new_columns = [
|
||||
col.split("/")[0].lower().replace(" ", "_").replace("-", "_")
|
||||
for col in spec.field_description
|
||||
] + ["line_number", "schedule_number"]
|
||||
_ = cursor.execute(f"DROP TABLE IF EXISTS {mappings[name].lower()}")
|
||||
_ = cursor.execute(
|
||||
f"""
|
||||
CREATE TABLE {mappings[name]}
|
||||
({(", ".join(new_columns))})
|
||||
""",
|
||||
)
|
||||
connection.commit()
|
||||
|
||||
zipinfo = [
|
||||
zipinfo
|
||||
for name, zipinfo in zipfile.NameToInfo.items()
|
||||
if name.split(".")[-1] == "MCA"
|
||||
][0]
|
||||
file = zipfile.open(zipinfo)
|
||||
schedule_number, line_number = -1, -1
|
||||
while (line := file.readline().decode()) != "":
|
||||
line_number += 1
|
||||
record_type = line[:2]
|
||||
schedule_number += int(record_type == "BS")
|
||||
target_table = f"raw_mca_{record_type.lower()}"
|
||||
start_indexes = all_start_indexes.get(target_table)
|
||||
end_indexes = all_end_indexes.get(target_table)
|
||||
if start_indexes is None or end_indexes is None:
|
||||
continue
|
||||
values = ", ".join(
|
||||
[
|
||||
"'" + line[lb:ub].replace("'", "") + "'"
|
||||
for lb, ub in zip(start_indexes, end_indexes, strict=True)
|
||||
]
|
||||
+ [f"{line_number}", f"{schedule_number}"]
|
||||
)
|
||||
_ = cursor.execute(f"INSERT INTO {target_table} VALUES({values})")
|
||||
if line_number % 3737 == 0 and print_progress:
|
||||
print(f" {line_number:9,} {line[:-1]}", end="\r")
|
||||
if print_progress:
|
||||
print()
|
||||
connection.commit()
|
||||
connection.close()
|
||||
return mappings
|
||||
|
||||
|
||||
# Script
|
||||
def main(
|
||||
skip_pdf: bool = False,
|
||||
pdf_spec_path: Path | None = None,
|
||||
raw_spec_dir: Path | None = None,
|
||||
):
|
||||
if not skip_pdf:
|
||||
try:
|
||||
tables = extract_specification_document_tables(pdf_spec_path)
|
||||
_ = store_specification_table_raws(tables, raw_spec_dir)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
try:
|
||||
tables = read_specification_table_raws(raw_spec_dir)
|
||||
except FileNotFoundError:
|
||||
raise FileNotFoundError(
|
||||
"The tables generated from the RSP's specification were not found. "
|
||||
+ "This means neither the cached version nor the original .pdf is available. "
|
||||
+ "Try suppling either to their default locations, or supplying custom directories. "
|
||||
+ "Manual fix: extract_specification_document_tables then store_specification_table_raws. "
|
||||
)
|
||||
|
||||
_ = create_mca_specification_dbtables(tables)
|
||||
return create_mca_raw_dbtables()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_ = main()
|
||||
|
||||
Reference in New Issue
Block a user