diff --git a/src/national_rail_timetable/__main__.py b/src/national_rail_timetable/__main__.py index 0b79a3f..2957be6 100644 --- a/src/national_rail_timetable/__main__.py +++ b/src/national_rail_timetable/__main__.py @@ -1,10 +1,3 @@ -from national_rail_timetable.nr_requests import fetch_nr_token, fetch_nr_timetable_files -from national_rail_timetable.parsing import ( - extract_specification_document_tables, - store_specification_table_raws, - read_specification_table_raws, - create_mca_specification_dbschema, - main, -) +from national_rail_timetable.parsing import main print(main()) diff --git a/src/national_rail_timetable/nr_requests.py b/src/national_rail_timetable/nr_requests.py index 6756829..6bf6efb 100644 --- a/src/national_rail_timetable/nr_requests.py +++ b/src/national_rail_timetable/nr_requests.py @@ -65,7 +65,7 @@ def fetch_nr_token( def fetch_nr_timetable_files( config: NRConfig | None = None, # pyright: ignore[reportRedeclaration] token: str | None = None, # pyright: ignore[reportRedeclaration] - attempts: int = 1, + attempts: int = 3, ) -> ZipFile: config: NRConfig = config if config is not None else NRConfig.from_env() token: str = ( diff --git a/src/national_rail_timetable/parsing.py b/src/national_rail_timetable/parsing.py index f7007a3..811b35c 100644 --- a/src/national_rail_timetable/parsing.py +++ b/src/national_rail_timetable/parsing.py @@ -9,14 +9,17 @@ Aimed primarily towards producing a reduced sqlite database. # pyright: reportUnknownLambdaType=false # Imports +import os +import sqlite3 from itertools import pairwise from pathlib import Path -import pandas as pd +from zipfile import ZipFile + import numpy as np -import sqlite3 -import os +import pandas as pd from pypdf import PageObject, PdfReader +from national_rail_timetable.nr_requests import fetch_nr_timetable_files # Init. SPECIFICATION_TABLE_LOCATIONS = { @@ -34,6 +37,7 @@ SPECIFICATION_TABLE_LOCATIONS = { "MCA_ZZ": (23, 1), } DEFAULT_RAW_SPEC_DATA_DIR = Path(__file__).parents[2] / "data/specification_tables" +DEFAULT_DB_PATH = Path.home() / ".cache/nr_data/timetable.db" # Functions @@ -149,16 +153,21 @@ def read_specification_table_raws( return tables -def create_mca_specification_dbschema( - tables: dict[str, pd.DataFrame], - db_path: Path | None = None, -): +def _validate_db_path(db_path: Path | None = None): db_path = ( db_path if db_path is not None - else Path(os.environ.get("NR_DATADIR", "~/.cache/nr_data/timetable.db")) + else Path(os.environ.get("NR_DATADIR", DEFAULT_DB_PATH)) ) db_path.parent.mkdir(exist_ok=True, parents=True) + return db_path + + +def create_mca_specification_dbtables( + tables: dict[str, pd.DataFrame], + db_path: Path | None = None, +): + db_path = _validate_db_path(db_path) connection = sqlite3.connect(db_path) cursor = connection.cursor() for name, df in tables.items(): @@ -179,13 +188,98 @@ def create_mca_specification_dbschema( INSERT INTO spec_{name.lower()} VALUES({", ".join(["?" for _ in df.columns])}) """, - [list(row.values) for _, row in df.iterrows()], + [tuple(row.values) for _, row in df.iterrows()], ) connection.commit() connection.close() return db_path +# TODO: There is no need for this to take minutes, investigate speed ups. +def create_mca_raw_dbtables( + zipfile: ZipFile | None = None, + db_path: Path | None = None, + allow_fetch: bool = True, +) -> dict[str, str]: + db_path = _validate_db_path(db_path) + if zipfile is None: + if allow_fetch: + zipfile = fetch_nr_timetable_files() + else: + raise RuntimeError( + "There was no zipfile provided and allow_fetch is set to False. " + + "Please either allow automatic fetching, or supply the zipfile argument. " + + "The package's fetching function is fetch_nr_timetable_files. " + ) + + connection = sqlite3.connect(db_path) + cursor = connection.cursor() + tables = [ + row[0] + for row in cursor.execute( + "SELECT name FROM sqlite_master WHERE type = 'table'" + ).fetchall() + if row[0][:9] == "spec_mca_" + ] + if len(tables) == 0: + raise FileNotFoundError( + "No spec_mca_... tables found in given database. " + + "Please ensure create_mca_specification_dbtables has been run successfully. " + ) + mappings = {} + all_start_indexes = {} + all_end_indexes = {} + for name in tables: + mappings[name] = "raw_" + name[5:] + spec = pd.DataFrame( + cursor.execute(f"SELECT * FROM {name}").fetchall(), + columns=[d[0] for d in cursor.description], + ) + all_start_indexes[mappings[name]] = spec.start_index.values + all_end_indexes[mappings[name]] = spec.end_index.values + new_columns = [ + col.split("/")[0].lower().replace(" ", "_").replace("-", "_") + for col in spec.field_description + ] + _ = cursor.execute(f"DROP TABLE IF EXISTS {mappings[name].lower()}") + _ = cursor.execute( + f""" + CREATE TABLE {mappings[name]} + ({(", ".join(new_columns))}) + """, + ) + connection.commit() + + zipinfo = [ + zipinfo + for name, zipinfo in zipfile.NameToInfo.items() + if name.split(".")[-1] == "MCA" + ][0] + file = zipfile.open(zipinfo) + counter = 0 + while (line := file.readline().decode()) != "": + record_type = line[:2] + target_table = f"raw_mca_{record_type.lower()}" + start_indexes = all_start_indexes.get(target_table) + end_indexes = all_end_indexes.get(target_table) + if start_indexes is None or end_indexes is None: + continue + values = ", ".join( + [ + "'" + line[lb:ub].replace("'", "") + "'" + for lb, ub in zip(start_indexes, end_indexes, strict=True) + ] + ) + _ = cursor.execute(f"INSERT INTO {target_table} VALUES({values})") + counter += 1 + if counter % 1111 == 0: + print(f" {counter:,}", end="\r") + print() + connection.commit() + connection.close() + return mappings + + # Script def main( skip_pdf: bool = False, @@ -209,8 +303,9 @@ def main( + "Manual fix: extract_specification_document_tables then store_specification_table_raws. " ) - _ = create_mca_specification_dbschema(tables) + _ = create_mca_specification_dbtables(tables) + return create_mca_raw_dbtables() if __name__ == "__main__": - main() + _ = main()