diff --git a/src/fares_site/api_calling.py b/src/fares_site/api_calling.py index 7cf6b3f..0719e90 100644 --- a/src/fares_site/api_calling.py +++ b/src/fares_site/api_calling.py @@ -18,7 +18,7 @@ def fares_query( toc: str | None = None, ) -> list[dict[str, dict[str, str]]]: url = f"https://fares.ballast-data.co.uk/fares?origin={origin}&destination={destination}" - url += f"&toc={toc}" if toc is not None else "" + url += f"&toc={toc.upper()}" if toc is not None else "" response = requests.get( url=url, auth=( diff --git a/src/fares_site/serve.py b/src/fares_site/serve.py index 3adfdc9..ee630cd 100644 --- a/src/fares_site/serve.py +++ b/src/fares_site/serve.py @@ -35,6 +35,8 @@ COLUMN_RENAMES = { "restriction_code": "Restriction", } +VALID_QUERY_TERMS = ["origin", "destination", "toc"] + DEFAULT_TEMPLATE_PATH = Path(__file__).parents[2] / "media/template.html" logger: Logger = Logger(name=__name__, level=INFO) @@ -46,10 +48,10 @@ class FaresHandler(BaseHTTPRequestHandler): text = self.requestline.split(" ")[1].split("?") if len(text) == 1: return {} - text = text[1] return { - (_s := s.split("="))[0]: _s[-1] if _s[-1] != "" else None - for s in text.split("&") + _s[0]: _s[-1] if _s[-1] != "" else None + for s in text[1].split("&") + if (_s := s.split("="))[0] in VALID_QUERY_TERMS } def content_of_GET(self, template_path: Path = DEFAULT_TEMPLATE_PATH) -> str: @@ -76,6 +78,7 @@ class FaresHandler(BaseHTTPRequestHandler): text = "No Fares Found." with open(template_path, "r") as rf: return rf.read().replace("", text) + text = '
| Flow | Fares |
|---|