88 lines
3.4 KiB
Python
88 lines
3.4 KiB
Python
import datetime
from typing import List, Optional

import httpx

from .objects import *
|
|
|
|
|
|
class SNCFException(Exception):
    """Error raised by the SNCF client, e.g. when it is used before being configured."""
|
|
|
|
|
|
class SNCF:
    """Async client for the SNCF journey-planning API.

    Configure an instance with the ``set_*`` methods, then await
    :meth:`get_station` or :meth:`get_next_journeys`.
    """

    # Every request path is appended to this coverage root.
    BASE_URL = "https://api.sncf.com/v1/coverage/sncf"

    def __init__(self, token: str):
        """Create a client that authenticates with *token*.

        :param token: API key, sent verbatim in the ``Authorization`` header.
        """
        self.token = token
        self.source_stop_point: str = ""
        self.dest_stop_point: str = ""
        # NOTE(review): stored by set_allowed_lines but not yet used by any
        # filtering in this class.
        self.allowed_lines: list[str] = []
        # -1 disables the transfer-count filter in get_next_journeys.
        self.max_transfers: int = 0
        # -1 disables the duration filter in get_next_journeys.
        self.max_duration_secs: int = 1800
        self.prefered_date_format = "%H:%M"
        # Flipped to True after the first successful token check so later
        # requests skip re-validation.
        self.token_validated: bool = False
        self.physical_mode: Optional[PhysicalModes] = None

    def set_source_stop_point(self, stop_point: str):
        """Set the departure stop point identifier used by get_next_journeys."""
        self.source_stop_point = stop_point

    def set_dest_stop_point(self, stop_point: str):
        """Set the arrival stop point identifier used by get_next_journeys."""
        self.dest_stop_point = stop_point

    def set_allowed_lines(self, allowed_lines: list[str]):
        """Store the list of allowed line identifiers (currently not applied)."""
        self.allowed_lines = allowed_lines

    def _physical_mode_id(self) -> str:
        """Return the textual form of the configured physical mode.

        Uses ``.value`` when the mode is an enum member. ``f"{member}"`` on a
        plain Enum (or a ``(str, Enum)`` on Python 3.12+) yields
        ``Class.MEMBER`` instead of the value.
        """
        return str(getattr(self.physical_mode, "value", self.physical_mode))

    def _journey_within_limits(self, journey: dict) -> bool:
        """Return True if *journey* respects max_transfers and max_duration_secs (-1 = no limit)."""
        if self.max_transfers != -1 and journey["nb_transfers"] > self.max_transfers:
            return False
        return self.max_duration_secs == -1 or journey["duration"] <= self.max_duration_secs

    async def get_next_journeys(self, max_count: int = 1):
        """Fetch up to *max_count* upcoming journeys from source to destination.

        Journeys exceeding ``max_transfers`` or ``max_duration_secs`` are
        dropped. A limit of -1 disables that filter.

        :param max_count: number of journeys requested from the API.
        :returns: list of :class:`Journey` objects that pass the filters.
        :raises SNCFException: if :meth:`set_physical_mode` was not called.
        """
        if not self.physical_mode:
            raise SNCFException("Use set_physical_mode before using this function")

        mode_id = self._physical_mode_id()
        api_response = await self.api_request(
            "/journeys",
            params={
                "from": f"{self.source_stop_point}:{mode_id}",
                "to": f"{self.dest_stop_point}:{mode_id}",
                "datetime": self.date_now(),
                "count": max_count,
                "disable_geojson": "true",
            },
        )

        # A response without journeys yields an empty result, not a KeyError.
        journeys = api_response.get("journeys", [])
        valid_journeys: List[Journey] = [
            Journey(**journey)
            for journey in journeys
            if self._journey_within_limits(journey)
        ]
        return valid_journeys

    async def api_request(self, url: str, *, params: Optional[dict] = None, **kwargs):
        """Send an authenticated GET to ``BASE_URL + url`` and return the decoded JSON.

        Unless ``no_check=True`` is passed, the token is validated once (via
        :meth:`test_api_key`) before the first real request.

        :param url: path appended to ``BASE_URL`` (may already contain a query).
        :param params: optional query parameters; httpx URL-encodes them.
        :raises SNCFException: if the token fails validation.
        :raises httpx.HTTPStatusError: on a non-2xx response.
        :raises httpx.RequestError: on transport failures.
        """
        # Validate before opening our own client so the nested validation
        # request does not run while a second client is held open.
        if not kwargs.get("no_check", False) and not self.token_validated:
            if not await self.test_api_key():
                raise SNCFException("Token is invalid")
            self.token_validated = True

        async with httpx.AsyncClient() as client:
            response = await client.get(
                SNCF.BASE_URL + url,
                params=params,
                headers={"Authorization": self.token},
            )
            response.raise_for_status()
            return response.json()

    @staticmethod
    def date_now() -> str:
        """Return the current local time formatted as ``YYYYMMDDTHHMMSS``."""
        return datetime.datetime.now().strftime("%Y%m%dT%H%M%S")

    @staticmethod
    def api_date_to_datetime(api_date: str) -> datetime.datetime:
        """Parse an API timestamp in ``YYYYMMDDTHHMMSS`` form into a naive datetime."""
        return datetime.datetime.strptime(api_date, "%Y%m%dT%H%M%S")

    async def test_api_key(self) -> bool:
        """Return True if an unvalidated request to the coverage root succeeds.

        Both transport errors and non-2xx responses (e.g. 401 for a bad
        token, raised by ``raise_for_status``) count as failure.
        """
        try:
            await self.api_request("", no_check=True)
        except httpx.HTTPError:
            # HTTPError covers RequestError and HTTPStatusError.
            return False
        return True

    def get_cheapest_path(self) -> List[Journey]:
        """Not implemented yet; currently returns None."""
        # TODO: implement; returns None despite the annotation.
        pass

    def set_physical_mode(self, physical_mode: PhysicalModes):
        """Set the physical mode used to build stop point IDs and filter stations."""
        self.physical_mode = physical_mode

    async def get_station(self, name: str) -> List[StopArea]:
        """Search stop areas matching *name*.

        If a physical mode is configured, only stop areas listing that mode
        are returned. Each stop area appears at most once.

        :param name: free-text query; httpx URL-encodes it.
        """
        api_response = await self.api_request(
            "/places",
            params={
                "q": name,
                "disable_geojson": "true",
                "type": "stop_area",
                "depth": 2,
            },
        )
        places = [Place(**item) for item in api_response.get("places", [])]

        stops = []
        for place in places:
            stop_area = place.stop_area
            if not stop_area:
                continue
            # any() stops at the first match, so a stop area is never added twice.
            if not self.physical_mode or any(
                self.physical_mode == mode.value for mode in stop_area.physical_modes
            ):
                stops.append(stop_area)
        return stops
|