diff --git a/setup.py b/setup.py index 56c8d79..a1a4223 100644 --- a/setup.py +++ b/setup.py @@ -3,9 +3,17 @@ import requests import sys from typing import List import configparser +import mysql.connector HTTP = requests.session() +VRR_STATIONS_TABLE = """ +CREATE TABLE IF NOT EXISTS vrr_stations ( +station_id int not null primary key, +station_name text); + +""" + def yn(s: str) -> bool: return s in ['y', 'Y', 'J', 'j', ''] @@ -24,13 +32,27 @@ def search_station(search: str) -> List or None: return resp.json()['suggestions'] -def get_station() -> int: +def add_stations_to_db(results: list, db: dict) -> None: + cx = mysql.connector.connect(**db) + cr = cx.cursor() + cr.execute(VRR_STATIONS_TABLE) + for r in results: + cr.execute('REPLACE INTO vrr_stations' + '(station_id, station_name)' + 'VALUES (%s, %s)', (r['data'], r['value'])) + cx.commit() + cr.close() + cx.close() + + +def get_station(db: dict) -> int: station_id = None while station_id is None: search = input("Which station would you like to monitor? ") print("Getting suggestions...") results = search_station(search) if results: # empty lists and None are False + add_stations_to_db(results, db) for i, result in enumerate(results): print(str(i) + ". " + result['value'] + "\t" + result['data']) choice_ptr = None @@ -87,8 +109,29 @@ def get_lines(station_id: int) -> List[str]: for r in filt_arr] +def config_db() -> dict: + def _enter_details() -> dict: + r = dict() + r['host'] = input("Please enter the database hostname: ") + r['user'] = input("Please enter the database user: ") + r['password'] = input("Please enter the database password: ") + r['database'] = input("Please enter the database name: ") + return r + successful = False + while not successful: + r = _enter_details() + try: + cx = mysql.connector.connect(**r) + except: + print("The database settings seem incorrect. Please try again.") + else: + successful = True + return r + + def setup() -> None: - station_id = get_station() + db_config = config_db() + station_id = get_station(db_config) lines_ch = input("Would you like to choose specific lines? (Y/n)", ) if yn(lines_ch): lines = get_lines(station_id) @@ -109,10 +152,10 @@ def setup() -> None: cfg['crawl']['use_elevated_trains'] = 'yes' cfg['crawl']['use_lines'] = ",".join(lines) if lines is not None else "" cfg.add_section('db') - cfg['db']['user'] = 'vrr' - cfg['db']['pass'] = 'vrr' - cfg['db']['host'] = 'localhost' - cfg['db']['database'] = 'vrr' + cfg['db']['user'] = db_config['user'] + cfg['db']['pass'] = db_config['password'] + cfg['db']['host'] = db_config['host'] + cfg['db']['database'] = db_config['database'] print("Please save the following output to 'vrr.ini' and adjust any further settings:") print("\n" * 3) cfg.write(sys.stdout)