This commit is contained in:
Amazed 2022-09-03 19:53:08 +02:00
parent 3a9d88fad1
commit b7d85f0763
5 changed files with 421 additions and 32 deletions

28
main.py Normal file
View File

@ -0,0 +1,28 @@
import asyncio
from sncf.sncf import SNCF, PhysicalModes
async def main():
sncf = SNCF("5722269b-2e58-49ee-986b-21741438d5ff") # instanciate SNCF client with auth token
sncf.set_physical_mode(PhysicalModes.Train) # set prefered mode (others can be Autocar, LongDistanceTrain (TGV), ...)
for stop in await sncf.get_station("Benfeld"):
print(stop.name, stop.id)
print("Available physical modes: ", [mode.value for mode in stop.physical_modes])
print("----------------")
sncf.set_source_stop_point("stop_point:SNCF:87214056") # Sélestat
sncf.set_dest_stop_point("stop_point:SNCF:87212027") # Strasbourg
sncf.max_transfers = 2
sncf.max_duration_secs = 20 * 60
next_journeys = await sncf.get_next_journeys(max_count=5)
for next_journey in next_journeys:
print("(D)", next_journey.departure_date_time, "-", "(A)", next_journey.arrival_date_time,
"- Transfers:", next_journey.nb_transfers,
"- Duration:", next_journey.duration, "seconds"
)
if __name__ == '__main__':
asyncio.run(main())

2
requirements.txt Normal file
View File

@ -0,0 +1,2 @@
httpx
pydantic

View File

@ -1,3 +1,5 @@
from .sncf import SNCF, SNCFJourney from .sncf import SNCF
from .objects import Journey
from .objects import *
__all__ = ['SNCF', 'SNCFJourney'] #__all__ = ['SNCF']

343
sncf/objects.py Normal file
View File

@ -0,0 +1,343 @@
import enum
from pydantic import BaseModel, validator, Field
import datetime
from typing import List, Optional
BaseModel.Config.extra = "forbid"
class Coord(BaseModel):
lat: float
lon: float
class Code(BaseModel):
type: str
value: str
class CO2Emission(BaseModel):
value: float
unit: str
class Link(BaseModel):
internal: Optional[bool]
type: str
id: Optional[str]
rel: Optional[str]
templated: Optional[bool]
href: Optional[str]
class AdministrativeRegion(BaseModel):
id: str
name: str
label: str
coord: Coord
level: int
zip_code: str
insee: int
class TravelMode(BaseModel):
id: str
name: str
@property
def value(self) -> str:
return self.id.split(":")[-1]
class PhysicalMode(TravelMode):
@validator("id")
def validate_physicalmode(cls, v):
for physicalmode in PhysicalModes:
if v == f"physical_mode:{physicalmode}":
return v
raise ValueError(v)
class PhysicalModes(str, enum.Enum):
Air = "Air"
Boat = "Boat"
Bus = "Bus"
BusRapidTransit = "BusRapidTransit"
Coach = "Coach"
Ferry = "Ferry"
Funicular = "Funicular"
LocalTrain = "LocalTrain"
LongDistanceTrain = "LongDistanceTrain"
Metro = "Metro"
RailShuttle = "RailShuttle"
RapidTransit = "RapidTransit"
Shuttle = "Shuttle"
SuspendedCableCar = "SuspendedCableCar"
Taxi = "Taxi"
Train = "Train"
Tramway = "Tramway"
class StopArea(BaseModel):
id: str
name: str
label: str
coord: Coord
administrative_regions: Optional[List[AdministrativeRegion]]
stop_points: Optional[List["StopPoint"]]
codes: List[Code]
links: List[Link]
timezone: str
commercial_modes: Optional[List[TravelMode]]
physical_modes: Optional[List[PhysicalMode]]
class StopPoint(BaseModel):
id: str
name: str
coord: Coord
administrative_regions: Optional[List[AdministrativeRegion]]
equipments: List[str]
stop_area: Optional[StopArea]
links: List[Link]
label: str
class JourneyStatus(str, enum.Enum):
NONE = ""
NO_SERVICE = "NO_SERVICE"
REDUCED_SERVICE = "REDUCED_SERVICE"
SIGNIFICANT_DELAYS = "SIGNIFICANT_DELAYS"
DETOUR = "DETOUR"
ADDITIONAL_SERVICE = "ADDITIONAL_SERVICE"
MODIFIED_SERVICE = "MODIFIED_SERVICE"
OTHER_EFFECT = "OTHER_EFFECT"
UNKNOWN_EFFECT = "UNKNOWN_EFFECT"
STOP_MOVED = "STOP_MOVED"
class SectionType(str, enum.Enum):
public_transport = "public_transport"
street_network = "street_network"
waiting = "waiting"
stay_in = "stay_in"
transfer = "transfer"
crow_fly = "crow_fly"
on_demand_transport = "on_demand_transport"
bss_rent = "bss_rent"
bss_put_back = "bss_put_back"
boarding = "boarding"
landing = "landing"
alighting = "alighting"
park = "park"
ridesharing = "ridesharing"
class SectionMode(str, enum.Enum):
NONE = ""
Walking = "walking"
Bike = "bike"
Car = "car"
Taxi = "taxi"
class EmbeddedType(str, enum.Enum):
administrative_region = "administrative_region"
stop_area = "stop_area"
stop_point = "stop_point"
address = "address"
poi = "poi"
class POIType(BaseModel):
id: str
name: str
class POI(BaseModel):
id: str
name: str
label: str
type: POIType
# stands: todo
class Address(BaseModel):
id: str
name: str
label: str
coord: Coord
house_number: int
administrative_regions: List[AdministrativeRegion]
class Place(BaseModel):
id: str
name: str
quality: int
embedded_type: EmbeddedType
administrative_region: Optional[AdministrativeRegion]
stop_area: Optional[StopArea]
stop_point: Optional[StopPoint]
address: Optional[Address]
poi: Optional[POI]
@property
def embed(self) -> AdministrativeRegion | StopArea | POI | Address | StopPoint:
return getattr(self, self.embedded_type)
class DisplayInformation(BaseModel):
network: str
physical_mode: str
commercial_mode: str
code: str
color: str
text_color: str
direction: str
headsign: str
label: str
name: str
trip_short_name: str
equipments: List[str]
description: str
links: List[Link]
class Path(BaseModel):
length: int
name: str
duration: int
direction: int
class TransferType(str, enum.Enum):
walking = "walking"
stay_in = "stay_in"
class StopDateTime(BaseModel):
stop_point: StopPoint
additional_informations: List[str]
base_departure_date_time: Optional[datetime.datetime]
departure_date_time: Optional[datetime.datetime]
base_arrival_date_time: Optional[datetime.datetime]
arrival_date_time: Optional[datetime.datetime]
links: List[str]
@validator(
"base_departure_date_time",
"departure_date_time",
"base_arrival_date_time",
"arrival_date_time",
pre=True, allow_reuse=True)
def convert_datetime(cls, v):
return datetime.datetime.strptime(v, "%Y%m%dT%H%M%S")
class DataFreshness(str, enum.Enum):
realtime = "realtime"
base_schedule = "base_schedule"
class Section(BaseModel):
type: SectionType
id: str
mode: Optional[SectionMode]
duration: int
from_place: Optional[Place] = Field(alias='from')
to_place: Optional[Place] = Field(alias='to')
links: List[Link]
display_informations: Optional[DisplayInformation]
additional_informations: Optional[List[str]] # todo: enum
geojson: Optional[dict] # todo
path: Optional[List[Path]]
transfer_type: Optional[TransferType]
stop_date_times: Optional[List[StopDateTime]]
departure_date_time: datetime.datetime
arrival_date_time: datetime.datetime
co2_emission: CO2Emission
data_freshness: Optional[DataFreshness]
base_departure_date_time: Optional[datetime.datetime]
base_arrival_date_time: Optional[datetime.datetime]
@validator(
"departure_date_time",
"arrival_date_time",
"base_departure_date_time",
"base_arrival_date_time",
pre=True, allow_reuse=True)
def convert_datetime(cls, v):
return datetime.datetime.strptime(v, "%Y%m%dT%H%M%S")
class Cost(BaseModel):
value: float
currency: Optional[str]
class Fare(BaseModel):
total: Cost
found: bool
links: List[Link]
class Durations(BaseModel):
taxi: int
walking: int
car: int
ridesharing: int
bike: int
total: Optional[int]
class Period(BaseModel):
begin: datetime.date
end: datetime.date
@validator(
"begin",
"end",
pre=True, allow_reuse=True)
def convert_datetime(cls, v):
return datetime.datetime.strptime(v, "%Y%m%d")
class WeekPattern(BaseModel):
monday: bool
tuesday: bool
wednesday: bool
thursday: bool
friday: bool
saturday: bool
sunday: bool
class Calendar(BaseModel):
active_periods: List[Period]
week_pattern: WeekPattern
class Journey(BaseModel):
duration: int
nb_transfers: int
departure_date_time: datetime.datetime
requested_date_time: datetime.datetime
arrival_date_time: datetime.datetime
sections: List[Section]
links: List[Link]
type: str # todo enum
fare: Fare
tags: List[str]
status: JourneyStatus
from_place: Place = None
to: Place = None
durations: Durations
distances: Durations
co2_emission: CO2Emission
calendars: List[Calendar]
@validator(
"departure_date_time",
"requested_date_time",
"arrival_date_time",
pre=True, allow_reuse=True)
def convert_datetime(cls, v):
return datetime.datetime.strptime(v, "%Y%m%dT%H%M%S")

View File

@ -1,15 +1,9 @@
import datetime import httpx
import requests from .objects import *
class SNCFJourney(dict): class SNCFException(Exception):
def __repr__(self): pass
sections = []
for section in self["sections"]:
nice_dep_date = SNCF.api_date_to_datetime(section['departure_date_time']).strftime("%H:%M")
nice_arr_date = SNCF.api_date_to_datetime(section['arrival_date_time']).strftime("%H:%M")
sections.append(f"{nice_dep_date} {section['from']['name']} - {section['to']['name']} {nice_arr_date}")
return "\n".join(sections)
class SNCF: class SNCF:
@ -23,6 +17,8 @@ class SNCF:
self.max_transfers: int = 0 self.max_transfers: int = 0
self.max_duration_secs: int = 1800 self.max_duration_secs: int = 1800
self.prefered_date_format = "%H:%M" self.prefered_date_format = "%H:%M"
self.token_validated: bool = False
self.physical_mode: Optional[PhysicalMode] = None
def set_source_stop_point(self, stop_point: str): def set_source_stop_point(self, stop_point: str):
self.source_stop_point = stop_point self.source_stop_point = stop_point
@ -33,19 +29,27 @@ class SNCF:
def set_allowed_lines(self, allowed_lines: list[str]): def set_allowed_lines(self, allowed_lines: list[str]):
self.allowed_lines = allowed_lines self.allowed_lines = allowed_lines
def get_next_journeys(self, count: int = 1): async def get_next_journeys(self, max_count: int = 1):
valid_journeys = [] if not self.physical_mode:
api_response = self.api_request(f"/journeys?from={self.source_stop_point}&to={self.dest_stop_point}&datetime={self.date_now()}&count={count}") raise SNCFException("Use set_physical_mode before using this function")
valid_journeys: List[Journey] = []
api_response = await self.api_request(f"/journeys?from={self.source_stop_point}:{self.physical_mode}&to={self.dest_stop_point}:{self.physical_mode}&datetime={self.date_now()}&count={max_count}&disable_geojson=true")
journeys = api_response["journeys"] journeys = api_response["journeys"]
for journey in journeys: for journey in journeys:
if journey["nb_transfers"] <= self.max_transfers and journey["duration"] <= self.max_duration_secs: if (self.max_transfers == -1 or journey["nb_transfers"] <= self.max_transfers) and (self.max_duration_secs == -1 or journey["duration"] <= self.max_duration_secs):
valid_journeys.append(SNCFJourney(journey)) valid_journeys.append(Journey(**journey))
return valid_journeys return valid_journeys
def api_request(self, url): async def api_request(self, url: str, **kwargs):
req = requests.get(SNCF.BASE_URL + url, headers={"Authorization": self.token}) async with httpx.AsyncClient() as client:
if not kwargs.get("no_check", False) and not self.token_validated:
if not await self.test_api_key():
raise Exception("Token is invalid")
self.token_validated = True
req = await client.get(SNCF.BASE_URL + url, headers={"Authorization": self.token})
req.raise_for_status() req.raise_for_status()
return req.json() json = req.json()
return json
@staticmethod @staticmethod
def date_now() -> str: def date_now() -> str:
@ -55,19 +59,29 @@ class SNCF:
def api_date_to_datetime(api_date: str) -> datetime: def api_date_to_datetime(api_date: str) -> datetime:
return datetime.datetime.strptime(api_date, "%Y%m%dT%H%M%S") return datetime.datetime.strptime(api_date, "%Y%m%dT%H%M%S")
def test_api_key(self) -> bool: async def test_api_key(self) -> bool:
try: try:
self.api_request("") await self.api_request("", no_check=True)
return True return True
except requests.HTTPError: except httpx.RequestError:
return False return False
def get_cheapest_path(self) -> List[Journey]:
pass
if __name__ == '__main__': def set_physical_mode(self, physical_mode: PhysicalModes):
sncf = SNCF("5722269b-2e58-49ee-986b-21741438d5ff") self.physical_mode = physical_mode
sncf.set_source_stop_point("stop_point:SNCF:87214056:Train")
sncf.set_dest_stop_point("stop_point:SNCF:87212027:Train") async def get_station(self, name: str) -> List[StopArea]:
next_journeys = sncf.get_next_journeys(count=5) api_response = await self.api_request(f"/places?q={name}&disable_geojson=true&type=stop_area&depth=2")
for next_journey in next_journeys: places = [Place(**item) for item in api_response["places"]]
print(next_journey) stops = []
#print(json.dumps(next_schedules, indent=4)) for place in places:
if place.stop_area:
if self.physical_mode:
for physical_mode in place.stop_area.physical_modes:
if self.physical_mode == physical_mode.value:
stops.append(place.stop_area)
else:
stops.append(place.stop_area)
return stops