hass-sncf/sncf/sncf.py

88 lines
3.4 KiB
Python

import httpx
from .objects import *
class SNCFException(Exception):
    """Error raised by the SNCF client.

    In this module it signals client misuse, e.g. requesting journeys
    before a physical mode has been selected.
    """
class SNCF:
    """Small asynchronous client for the SNCF journey API.

    Typical use: build the client with an API token, select the origin and
    destination stop points and a physical mode through the ``set_*``
    methods, then await :meth:`get_next_journeys`.

    Annotations that name project model types are quoted so they are only
    resolved on demand, not when the class body is executed.
    """

    BASE_URL = "https://api.sncf.com/v1/coverage/sncf"

    def __init__(self, token: str):
        """Store the API token and reset every search setting to its default.

        Args:
            token: Value sent verbatim in the ``Authorization`` header.
        """
        self.token = token
        self.source_stop_point: str = ""
        self.dest_stop_point: str = ""
        # Stored by set_allowed_lines(); nothing in this class reads it yet.
        self.allowed_lines: list[str] = []
        # For both limits below, -1 disables the filter in get_next_journeys().
        self.max_transfers: int = 0
        self.max_duration_secs: int = 1800
        # Not read inside this class; presumably consumed by callers.
        self.prefered_date_format = "%H:%M"
        # Flipped to True after the first successful token check.
        self.token_validated: bool = False
        self.physical_mode: "Optional[PhysicalModes]" = None

    def set_source_stop_point(self, stop_point: str):
        """Set the identifier used as the ``from`` end of journey searches."""
        self.source_stop_point = stop_point

    def set_dest_stop_point(self, stop_point: str):
        """Set the identifier used as the ``to`` end of journey searches."""
        self.dest_stop_point = stop_point

    def set_allowed_lines(self, allowed_lines: list[str]):
        """Remember the list of allowed lines (stored only, see __init__)."""
        self.allowed_lines = allowed_lines

    async def get_next_journeys(self, max_count: int = 1) -> "List[Journey]":
        """Fetch upcoming journeys and keep those within the configured limits.

        Args:
            max_count: Value sent as the ``count`` query parameter. Filtering
                happens afterwards, so fewer journeys may be returned.

        Returns:
            ``Journey`` objects whose ``nb_transfers`` and ``duration`` do not
            exceed ``max_transfers`` / ``max_duration_secs`` (a limit of -1
            means "no limit"). Empty when the response has no journeys.

        Raises:
            SNCFException: If no physical mode has been set.
        """
        if not self.physical_mode:
            raise SNCFException("Use set_physical_mode before using this function")

        api_response = await self.api_request(
            f"/journeys?from={self.source_stop_point}:{self.physical_mode}"
            f"&to={self.dest_stop_point}:{self.physical_mode}"
            f"&datetime={self.date_now()}&count={max_count}&disable_geojson=true"
        )

        valid_journeys: "List[Journey]" = []
        # A response without a "journeys" entry is treated as "nothing found"
        # instead of failing with KeyError.
        for journey in api_response.get("journeys") or []:
            if self.max_transfers != -1 and journey["nb_transfers"] > self.max_transfers:
                continue
            if self.max_duration_secs != -1 and journey["duration"] > self.max_duration_secs:
                continue
            valid_journeys.append(Journey(**journey))
        return valid_journeys

    async def api_request(self, url: str, **kwargs):
        """Send an authenticated GET request and return the decoded JSON body.

        Args:
            url: Path (optionally with a query string) appended to ``BASE_URL``.
            **kwargs: Recognised keys:
                ``no_check`` (bool) - skip the one-time token validation;
                ``params`` (mapping) - query parameters, encoded by HTTPX.
                Prefer this over formatting untrusted text into ``url``.

        Raises:
            SNCFException: If the one-time token validation fails.
            httpx.HTTPStatusError: If the response status is an error.
        """
        if not kwargs.get("no_check", False) and not self.token_validated:
            if not await self.test_api_key():
                raise SNCFException("Token is invalid")
            # Cache the result so later requests skip the extra round trip.
            self.token_validated = True

        async with httpx.AsyncClient() as client:
            response = await client.get(
                SNCF.BASE_URL + url,
                params=kwargs.get("params"),
                headers={"Authorization": self.token},
            )
            response.raise_for_status()
            return response.json()

    @staticmethod
    def date_now() -> str:
        """Return the current naive local time formatted as ``YYYYMMDDTHHMMSS``."""
        return datetime.datetime.now().strftime("%Y%m%dT%H%M%S")

    @staticmethod
    def api_date_to_datetime(api_date: str) -> datetime.datetime:
        """Parse a ``YYYYMMDDTHHMMSS`` string into a naive ``datetime``."""
        return datetime.datetime.strptime(api_date, "%Y%m%dT%H%M%S")

    async def test_api_key(self) -> bool:
        """Return whether a request to the API root succeeds with the token.

        ``httpx.HTTPError`` is the common base of transport failures
        (``RequestError``) and of the ``HTTPStatusError`` raised by
        ``raise_for_status()``, so a rejected token yields ``False`` instead
        of escaping as an exception.
        """
        try:
            await self.api_request("", no_check=True)
        except httpx.HTTPError:
            return False
        return True

    def get_cheapest_path(self) -> "List[Journey]":
        """Placeholder: not implemented yet, currently returns ``None``."""

    def set_physical_mode(self, physical_mode: "PhysicalModes"):
        """Select the physical mode used by journey and station searches."""
        self.physical_mode = physical_mode

    async def get_station(self, name: str) -> "List[StopArea]":
        """Search stop areas by free-text name.

        Args:
            name: Search text; sent as the ``q`` query parameter and
                URL-encoded by HTTPX, so characters such as ``&`` are safe.

        Returns:
            The ``stop_area`` of every matching place. When a physical mode
            is set, only stop areas listing that mode are kept, each at most
            once.
        """
        api_response = await self.api_request(
            "/places",
            params={
                "q": name,
                "disable_geojson": "true",
                "type": "stop_area",
                "depth": 2,
            },
        )
        places = [Place(**item) for item in api_response.get("places") or []]

        stops = []
        for place in places:
            stop_area = place.stop_area
            if not stop_area:
                continue
            if not self.physical_mode or any(
                self.physical_mode == mode.value for mode in stop_area.physical_modes
            ):
                stops.append(stop_area)
        return stops