import datetime
import urllib.parse
from typing import List, Optional

import httpx

from .objects import *


class SNCFException(Exception):
    """Error raised by the :class:`SNCF` client for misuse or rejected credentials."""

    pass


class SNCF:
    """Async client for the SNCF public transport API.

    Configure the origin/destination stop points and the physical mode with
    the setters, then call :meth:`get_next_journeys` or :meth:`get_station`.
    Every request opens a fresh ``httpx.AsyncClient``.
    """

    BASE_URL = "https://api.sncf.com/v1/coverage/sncf"

    def __init__(self, token: str):
        """Create a client.

        Args:
            token: API key, sent verbatim as the ``Authorization`` header.
        """
        self.token = token
        self.source_stop_point: str = ""
        self.dest_stop_point: str = ""
        # Stored by set_allowed_lines; not currently used for filtering.
        self.allowed_lines: list[str] = []
        # -1 disables the transfer limit in get_next_journeys.
        self.max_transfers: int = 0
        # -1 disables the limit; compared against each journey's "duration".
        self.max_duration_secs: int = 1800
        self.prefered_date_format = "%H:%M"
        # Set after the first successful test_api_key, so it runs only once.
        self.token_validated: bool = False
        # Only ever assigned by set_physical_mode.
        self.physical_mode: Optional[PhysicalModes] = None

    def set_source_stop_point(self, stop_point: str):
        """Set the stop point used as the ``from`` of journey searches."""
        self.source_stop_point = stop_point

    def set_dest_stop_point(self, stop_point: str):
        """Set the stop point used as the ``to`` of journey searches."""
        self.dest_stop_point = stop_point

    def set_allowed_lines(self, allowed_lines: list[str]):
        """Store the list of allowed lines (not yet applied to any search)."""
        self.allowed_lines = allowed_lines

    async def get_next_journeys(self, max_count: int = 1):
        """Fetch upcoming journeys from the source to the destination stop point.

        The search starts at the current local time. Results are filtered by
        ``max_transfers`` and ``max_duration_secs``; either limit is disabled
        when set to -1.

        Args:
            max_count: Value sent as the API's ``count`` parameter.

        Returns:
            A list of :class:`Journey` objects passing both limits (may be empty).

        Raises:
            SNCFException: If no physical mode is set, or the token is rejected.
            httpx.HTTPError: On transport failures or non-2xx responses.
        """
        if not self.physical_mode:
            raise SNCFException("Use set_physical_mode before using this function")

        # NOTE(review): the mode is inserted via f-string formatting; if
        # PhysicalModes is an Enum, confirm its formatted form is what the
        # API expects (Enum formatting differs across Python versions).
        url = (
            f"/journeys?from={self.source_stop_point}:{self.physical_mode}"
            f"&to={self.dest_stop_point}:{self.physical_mode}"
            f"&datetime={self.date_now()}&count={max_count}&disable_geojson=true"
        )
        api_response = await self.api_request(url)

        valid_journeys: List[Journey] = []
        # Tolerate a response without a "journeys" key instead of raising KeyError.
        for journey in api_response.get("journeys", []):
            transfers_ok = (
                self.max_transfers == -1
                or journey["nb_transfers"] <= self.max_transfers
            )
            duration_ok = (
                self.max_duration_secs == -1
                or journey["duration"] <= self.max_duration_secs
            )
            if transfers_ok and duration_ok:
                valid_journeys.append(Journey(**journey))
        return valid_journeys

    async def api_request(self, url: str, **kwargs):
        """GET ``BASE_URL + url`` and return the decoded JSON body.

        Unless ``no_check=True`` is passed, the token is validated with
        :meth:`test_api_key` the first time a request is made.

        Raises:
            SNCFException: If token validation fails.
            httpx.HTTPError: On transport failures or non-2xx responses.
        """
        # Validate before opening this request's client so the check does not
        # run nested inside an already-open connection pool.
        if not kwargs.get("no_check", False) and not self.token_validated:
            if not await self.test_api_key():
                raise SNCFException("Token is invalid")
            self.token_validated = True

        async with httpx.AsyncClient() as client:
            response = await client.get(
                SNCF.BASE_URL + url, headers={"Authorization": self.token}
            )
            response.raise_for_status()
            return response.json()

    @staticmethod
    def date_now() -> str:
        """Return the current local time in the API's ``YYYYMMDDTHHMMSS`` format."""
        return datetime.datetime.now().strftime("%Y%m%dT%H%M%S")

    @staticmethod
    def api_date_to_datetime(api_date: str) -> datetime.datetime:
        """Parse a ``YYYYMMDDTHHMMSS`` API timestamp into a naive datetime."""
        return datetime.datetime.strptime(api_date, "%Y%m%dT%H%M%S")

    async def test_api_key(self) -> bool:
        """Return True if an unchecked request to ``BASE_URL`` succeeds.

        Any httpx error (transport failure, or a non-2xx status such as a
        rejected token) is reported as False.
        """
        try:
            await self.api_request("", no_check=True)
        except httpx.HTTPError:
            # HTTPError covers RequestError (transport) and HTTPStatusError
            # (from raise_for_status). Catching only RequestError let rejected
            # tokens escape as HTTPStatusError.
            return False
        return True

    def get_cheapest_path(self) -> List[Journey]:
        """Not implemented yet; currently returns None."""
        pass

    def set_physical_mode(self, physical_mode: PhysicalModes):
        """Set the physical mode used to build journey queries and filter stations."""
        self.physical_mode = physical_mode

    async def get_station(self, name: str) -> List[StopArea]:
        """Search stop areas matching ``name``.

        When a physical mode is set, only stop areas listing that mode are
        returned; otherwise every stop area found is returned.

        Raises:
            SNCFException: If the token is rejected.
            httpx.HTTPError: On transport failures or non-2xx responses.
        """
        # Percent-encode the user-supplied name so characters like '&' or '#'
        # cannot break out of the q= parameter.
        query = urllib.parse.quote(name)
        api_response = await self.api_request(
            f"/places?q={query}&disable_geojson=true&type=stop_area&depth=2"
        )
        places = [Place(**item) for item in api_response.get("places", [])]

        stops = []
        for place in places:
            stop_area = place.stop_area
            if not stop_area:
                continue
            # any() appends each stop area at most once, even if its mode list
            # contains the requested mode more than once.
            if not self.physical_mode or any(
                self.physical_mode == mode.value for mode in stop_area.physical_modes
            ):
                stops.append(stop_area)
        return stops