From b7d85f0763540b8fd9569050e3e2895b29d078aa Mon Sep 17 00:00:00 2001 From: Hipstercat Date: Sat, 3 Sep 2022 19:53:08 +0200 Subject: [PATCH] Update --- main.py | 28 ++++ requirements.txt | 2 + sncf/__init__.py | 6 +- sncf/objects.py | 343 +++++++++++++++++++++++++++++++++++++++++++++++ sncf/sncf.py | 74 +++++----- 5 files changed, 421 insertions(+), 32 deletions(-) create mode 100644 main.py create mode 100644 requirements.txt create mode 100644 sncf/objects.py diff --git a/main.py b/main.py new file mode 100644 index 0000000..9f5f199 --- /dev/null +++ b/main.py @@ -0,0 +1,28 @@ +import asyncio +from sncf.sncf import SNCF, PhysicalModes + + +async def main(): + sncf = SNCF("5722269b-2e58-49ee-986b-21741438d5ff") # instanciate SNCF client with auth token + sncf.set_physical_mode(PhysicalModes.Train) # set prefered mode (others can be Autocar, LongDistanceTrain (TGV), ...) + for stop in await sncf.get_station("Benfeld"): + print(stop.name, stop.id) + print("Available physical modes: ", [mode.value for mode in stop.physical_modes]) + + print("----------------") + + sncf.set_source_stop_point("stop_point:SNCF:87214056") # Sélestat + sncf.set_dest_stop_point("stop_point:SNCF:87212027") # Strasbourg + + sncf.max_transfers = 2 + sncf.max_duration_secs = 20 * 60 + + next_journeys = await sncf.get_next_journeys(max_count=5) + for next_journey in next_journeys: + print("(D)", next_journey.departure_date_time, "-", "(A)", next_journey.arrival_date_time, + "- Transfers:", next_journey.nb_transfers, + "- Duration:", next_journey.duration, "seconds" + ) + +if __name__ == '__main__': + asyncio.run(main()) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..bed8151 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +httpx +pydantic \ No newline at end of file diff --git a/sncf/__init__.py b/sncf/__init__.py index c8ac611..32f6bef 100644 --- a/sncf/__init__.py +++ b/sncf/__init__.py @@ -1,3 +1,5 @@ -from .sncf import SNCF, SNCFJourney +from .sncf import SNCF +from .objects import Journey +from .objects import * -__all__ = ['SNCF', 'SNCFJourney'] \ No newline at end of file +#__all__ = ['SNCF'] \ No newline at end of file diff --git a/sncf/objects.py b/sncf/objects.py new file mode 100644 index 0000000..85aced3 --- /dev/null +++ b/sncf/objects.py @@ -0,0 +1,343 @@ +import enum +from pydantic import BaseModel, validator, Field +import datetime +from typing import List, Optional + +BaseModel.Config.extra = "forbid" + + +class Coord(BaseModel): + lat: float + lon: float + + +class Code(BaseModel): + type: str + value: str + +class CO2Emission(BaseModel): + value: float + unit: str + + +class Link(BaseModel): + internal: Optional[bool] + type: str + id: Optional[str] + rel: Optional[str] + templated: Optional[bool] + href: Optional[str] + + +class AdministrativeRegion(BaseModel): + id: str + name: str + label: str + coord: Coord + level: int + zip_code: str + insee: int + + +class TravelMode(BaseModel): + id: str + name: str + + @property + def value(self) -> str: + return self.id.split(":")[-1] + + +class PhysicalMode(TravelMode): + @validator("id") + def validate_physicalmode(cls, v): + for physicalmode in PhysicalModes: + if v == f"physical_mode:{physicalmode}": + return v + raise ValueError(v) + + +class PhysicalModes(str, enum.Enum): + Air = "Air" + Boat = "Boat" + Bus = "Bus" + BusRapidTransit = "BusRapidTransit" + Coach = "Coach" + Ferry = "Ferry" + Funicular = "Funicular" + LocalTrain = "LocalTrain" + LongDistanceTrain = "LongDistanceTrain" + Metro = "Metro" + RailShuttle = "RailShuttle" + RapidTransit = "RapidTransit" + Shuttle = "Shuttle" + SuspendedCableCar = "SuspendedCableCar" + Taxi = "Taxi" + Train = "Train" + Tramway = "Tramway" + + +class StopArea(BaseModel): + id: str + name: str + label: str + coord: Coord + administrative_regions: Optional[List[AdministrativeRegion]] + stop_points: Optional[List["StopPoint"]] + codes: List[Code] + links: List[Link] + timezone: str + commercial_modes: Optional[List[TravelMode]] + physical_modes: Optional[List[PhysicalMode]] + + +class StopPoint(BaseModel): + id: str + name: str + coord: Coord + administrative_regions: Optional[List[AdministrativeRegion]] + equipments: List[str] + stop_area: Optional[StopArea] + links: List[Link] + label: str + + +class JourneyStatus(str, enum.Enum): + NONE = "" + NO_SERVICE = "NO_SERVICE" + REDUCED_SERVICE = "REDUCED_SERVICE" + SIGNIFICANT_DELAYS = "SIGNIFICANT_DELAYS" + DETOUR = "DETOUR" + ADDITIONAL_SERVICE = "ADDITIONAL_SERVICE" + MODIFIED_SERVICE = "MODIFIED_SERVICE" + OTHER_EFFECT = "OTHER_EFFECT" + UNKNOWN_EFFECT = "UNKNOWN_EFFECT" + STOP_MOVED = "STOP_MOVED" + + +class SectionType(str, enum.Enum): + public_transport = "public_transport" + street_network = "street_network" + waiting = "waiting" + stay_in = "stay_in" + transfer = "transfer" + crow_fly = "crow_fly" + on_demand_transport = "on_demand_transport" + bss_rent = "bss_rent" + bss_put_back = "bss_put_back" + boarding = "boarding" + landing = "landing" + alighting = "alighting" + park = "park" + ridesharing = "ridesharing" + + +class SectionMode(str, enum.Enum): + NONE = "" + Walking = "walking" + Bike = "bike" + Car = "car" + Taxi = "taxi" + + +class EmbeddedType(str, enum.Enum): + administrative_region = "administrative_region" + stop_area = "stop_area" + stop_point = "stop_point" + address = "address" + poi = "poi" + + +class POIType(BaseModel): + id: str + name: str + + +class POI(BaseModel): + id: str + name: str + label: str + type: POIType + # stands: todo + + +class Address(BaseModel): + id: str + name: str + label: str + coord: Coord + house_number: int + administrative_regions: List[AdministrativeRegion] + + +class Place(BaseModel): + id: str + name: str + quality: int + embedded_type: EmbeddedType + administrative_region: Optional[AdministrativeRegion] + stop_area: Optional[StopArea] + stop_point: Optional[StopPoint] + address: Optional[Address] + poi: Optional[POI] + + @property + def embed(self) -> AdministrativeRegion | StopArea | POI | Address | StopPoint: + return getattr(self, self.embedded_type) + + +class DisplayInformation(BaseModel): + network: str + physical_mode: str + commercial_mode: str + code: str + color: str + text_color: str + direction: str + headsign: str + label: str + name: str + trip_short_name: str + equipments: List[str] + description: str + links: List[Link] + + +class Path(BaseModel): + length: int + name: str + duration: int + direction: int + + +class TransferType(str, enum.Enum): + walking = "walking" + stay_in = "stay_in" + + +class StopDateTime(BaseModel): + stop_point: StopPoint + additional_informations: List[str] + base_departure_date_time: Optional[datetime.datetime] + departure_date_time: Optional[datetime.datetime] + base_arrival_date_time: Optional[datetime.datetime] + arrival_date_time: Optional[datetime.datetime] + links: List[str] + + @validator( + "base_departure_date_time", + "departure_date_time", + "base_arrival_date_time", + "arrival_date_time", + pre=True, allow_reuse=True) + def convert_datetime(cls, v): + return datetime.datetime.strptime(v, "%Y%m%dT%H%M%S") + + +class DataFreshness(str, enum.Enum): + realtime = "realtime" + base_schedule = "base_schedule" + + +class Section(BaseModel): + type: SectionType + id: str + mode: Optional[SectionMode] + duration: int + from_place: Optional[Place] = Field(alias='from') + to_place: Optional[Place] = Field(alias='to') + links: List[Link] + display_informations: Optional[DisplayInformation] + additional_informations: Optional[List[str]] # todo: enum + geojson: Optional[dict] # todo + path: Optional[List[Path]] + transfer_type: Optional[TransferType] + stop_date_times: Optional[List[StopDateTime]] + departure_date_time: datetime.datetime + arrival_date_time: datetime.datetime + co2_emission: CO2Emission + data_freshness: Optional[DataFreshness] + base_departure_date_time: Optional[datetime.datetime] + base_arrival_date_time: Optional[datetime.datetime] + + @validator( + "departure_date_time", + "arrival_date_time", + "base_departure_date_time", + "base_arrival_date_time", + pre=True, allow_reuse=True) + def convert_datetime(cls, v): + return datetime.datetime.strptime(v, "%Y%m%dT%H%M%S") + + +class Cost(BaseModel): + value: float + currency: Optional[str] + + +class Fare(BaseModel): + total: Cost + found: bool + links: List[Link] + + +class Durations(BaseModel): + taxi: int + walking: int + car: int + ridesharing: int + bike: int + total: Optional[int] + + +class Period(BaseModel): + begin: datetime.date + end: datetime.date + @validator( + "begin", + "end", + pre=True, allow_reuse=True) + def convert_datetime(cls, v): + return datetime.datetime.strptime(v, "%Y%m%d") + + +class WeekPattern(BaseModel): + monday: bool + tuesday: bool + wednesday: bool + thursday: bool + friday: bool + saturday: bool + sunday: bool + + +class Calendar(BaseModel): + active_periods: List[Period] + week_pattern: WeekPattern + +class Journey(BaseModel): + duration: int + nb_transfers: int + departure_date_time: datetime.datetime + requested_date_time: datetime.datetime + arrival_date_time: datetime.datetime + sections: List[Section] + links: List[Link] + type: str # todo enum + fare: Fare + tags: List[str] + status: JourneyStatus + from_place: Place = None + to: Place = None + durations: Durations + distances: Durations + co2_emission: CO2Emission + calendars: List[Calendar] + + @validator( + "departure_date_time", + "requested_date_time", + "arrival_date_time", + pre=True, allow_reuse=True) + def convert_datetime(cls, v): + return datetime.datetime.strptime(v, "%Y%m%dT%H%M%S") diff --git a/sncf/sncf.py b/sncf/sncf.py index e1f965a..a873d8b 100644 --- a/sncf/sncf.py +++ b/sncf/sncf.py @@ -1,15 +1,9 @@ -import datetime -import requests +import httpx +from .objects import * -class SNCFJourney(dict): - def __repr__(self): - sections = [] - for section in self["sections"]: - nice_dep_date = SNCF.api_date_to_datetime(section['departure_date_time']).strftime("%H:%M") - nice_arr_date = SNCF.api_date_to_datetime(section['arrival_date_time']).strftime("%H:%M") - sections.append(f"{nice_dep_date} {section['from']['name']} - {section['to']['name']} {nice_arr_date}") - return "\n".join(sections) +class SNCFException(Exception): + pass class SNCF: @@ -23,6 +17,8 @@ class SNCF: self.max_transfers: int = 0 self.max_duration_secs: int = 1800 self.prefered_date_format = "%H:%M" + self.token_validated: bool = False + self.physical_mode: Optional[PhysicalMode] = None def set_source_stop_point(self, stop_point: str): self.source_stop_point = stop_point @@ -33,19 +29,27 @@ class SNCF: def set_allowed_lines(self, allowed_lines: list[str]): self.allowed_lines = allowed_lines - def get_next_journeys(self, count: int = 1): - valid_journeys = [] - api_response = self.api_request(f"/journeys?from={self.source_stop_point}&to={self.dest_stop_point}&datetime={self.date_now()}&count={count}") + async def get_next_journeys(self, max_count: int = 1): + if not self.physical_mode: + raise SNCFException("Use set_physical_mode before using this function") + valid_journeys: List[Journey] = [] + api_response = await self.api_request(f"/journeys?from={self.source_stop_point}:{self.physical_mode}&to={self.dest_stop_point}:{self.physical_mode}&datetime={self.date_now()}&count={max_count}&disable_geojson=true") journeys = api_response["journeys"] for journey in journeys: - if journey["nb_transfers"] <= self.max_transfers and journey["duration"] <= self.max_duration_secs: - valid_journeys.append(SNCFJourney(journey)) + if (self.max_transfers == -1 or journey["nb_transfers"] <= self.max_transfers) and (self.max_duration_secs == -1 or journey["duration"] <= self.max_duration_secs): + valid_journeys.append(Journey(**journey)) return valid_journeys - def api_request(self, url): - req = requests.get(SNCF.BASE_URL + url, headers={"Authorization": self.token}) - req.raise_for_status() - return req.json() + async def api_request(self, url: str, **kwargs): + async with httpx.AsyncClient() as client: + if not kwargs.get("no_check", False) and not self.token_validated: + if not await self.test_api_key(): + raise Exception("Token is invalid") + self.token_validated = True + req = await client.get(SNCF.BASE_URL + url, headers={"Authorization": self.token}) + req.raise_for_status() + json = req.json() + return json @staticmethod def date_now() -> str: @@ -55,19 +59,29 @@ class SNCF: def api_date_to_datetime(api_date: str) -> datetime: return datetime.datetime.strptime(api_date, "%Y%m%dT%H%M%S") - def test_api_key(self) -> bool: + async def test_api_key(self) -> bool: try: - self.api_request("") + await self.api_request("", no_check=True) return True - except requests.HTTPError: + except httpx.RequestError: return False + def get_cheapest_path(self) -> List[Journey]: + pass -if __name__ == '__main__': - sncf = SNCF("5722269b-2e58-49ee-986b-21741438d5ff") - sncf.set_source_stop_point("stop_point:SNCF:87214056:Train") - sncf.set_dest_stop_point("stop_point:SNCF:87212027:Train") - next_journeys = sncf.get_next_journeys(count=5) - for next_journey in next_journeys: - print(next_journey) - #print(json.dumps(next_schedules, indent=4)) + def set_physical_mode(self, physical_mode: PhysicalModes): + self.physical_mode = physical_mode + + async def get_station(self, name: str) -> List[StopArea]: + api_response = await self.api_request(f"/places?q={name}&disable_geojson=true&type=stop_area&depth=2") + places = [Place(**item) for item in api_response["places"]] + stops = [] + for place in places: + if place.stop_area: + if self.physical_mode: + for physical_mode in place.stop_area.physical_modes: + if self.physical_mode == physical_mode.value: + stops.append(place.stop_area) + else: + stops.append(place.stop_area) + return stops