# Source code for aio_sf_streaming.core

"""
Core module: define the root base class with main logic
"""
import abc
import asyncio
import copy
import logging
from typing import Any, AsyncIterator, Dict, List, Optional, Tuple, Union

import aiohttp

# Package-wide logger: every message of this module is emitted under the
# "aio_sf_streaming" logger name.
logger = logging.getLogger("aio_sf_streaming")

# Typing utils
# A JSON object: a dictionary with string keys and arbitrary values.
JSONObject = Dict[str, Any]
# A JSON array of objects (used below as the type of CometD responses).
JSONList = List[JSONObject]
# Any JSON payload handled by the IO helpers: a single object or a list of them.
JSONType = Union[JSONObject, JSONList]


class BaseSalesforceStreaming(abc.ABC):
    """
    Base low-level *aio-sf-streaming* class.

    Can not be used directly: must be sub-classed with at least one connector
    implementation. The class provides basic functionalities. Additional
    functionalities can be added with provided mixins.

    The main logic is implemented here but you should not use it directly.

    :param sandbox: If ``True``, the connexion will be made on a sandbox,
        from ``https://test.salesforce.com`` instead of the main login
        route at ``https://login.salesforce.com``.
    :param version: The API version to use. For example ``'42.0'``.
    :param loop: Asyncio loop used
    :param connector: ``aiohttp`` connector used for main session. Mainly used
        for test purpose.

    This class supports the context manager protocol for self closing.

    All main members are coroutines, even if the default implementation does
    not do any asynchronous call. With this convention, sub classes and mixins
    can easily override these members and do complex calls.

    See :py:class:`SimpleSalesforceStreaming` for an usage example.
    """

    version: str  #: SF api version to use
    sandbox: bool  #: Use test server
    connector: Optional[aiohttp.BaseConnector]  #: aiohttp connector for main session
    instance_url: Optional[str]  #: Instance url (retrieved with token)
    session: Optional[aiohttp.ClientSession]  #: Underlying connection
    client_id: Optional[str]  #: The client id token from handshake
    message_count: int  #: Message id count
    timeout: int  #: Timeout connection duration
    should_stop: bool  #: Set to True to stop streaming

    #: Header used in all requests
    base_header: dict = {"Accept": "application/json"}

    def __init__(
        self,
        *,
        sandbox: bool = False,
        version: str = "42.0",
        loop: Optional[asyncio.AbstractEventLoop] = None,
        connector: Optional[aiohttp.BaseConnector] = None,
    ) -> None:
        self.version = version
        self.sandbox = sandbox
        self._loop = loop
        self.connector = connector

        # Connection state, filled in by ``start`` / ``handshake``
        self.instance_url = None
        self.session = None
        self.client_id = None
        self.message_count = 0
        self.timeout = 120
        self.should_stop = False

        super().__init__()

    # -------------------- High level api --------------------

    async def start(self) -> None:
        """
        Connect to Salesforce, authenticate and init CometD connexion.

        A best practice is to use async context manager interface that will
        call this method directly.

        If the handshake fails, the session that was just created is closed
        before the error is propagated.
        """
        self.session = await self.create_connected_session()
        try:
            await self.handshake()
        except BaseException:
            # ``__aexit__`` is not invoked when ``__aenter__`` raises, so the
            # session must be released here or it would leak.
            await self.close_session()
            raise

    async def subscribe(self, channel: str) -> JSONList:
        """
        Subscribe to a channel. Can be used directly::

            await client.subscribe('/topic/Foo')

        This method, and the underlying protocol, are safe to be started as a
        background task::

            loop.create_task(client.subscribe('/topic/Foo'))
        """
        response = await self.send(await self.get_subscribe_payload(channel))
        logger.info("Subscribe response: %r", response)
        return response

    async def messages(self) -> AsyncIterator[JSONObject]:
        """
        Asynchronous generator that fetch new messages and return one as soon
        as one is available::

            async for message in client.messages():
                channel = message['channel']
                print(channel, ':', message)

        This method iterate over **all** messages, even on internal/meta one.
        If you want to only iterate over messages from channels you subscribed,
        you should use :py:func:`BaseSalesforceStreaming.events`.

        .. warning::

            Linked to the underlying protocol, long-polling based, the client
            should reconnect as soon as possible. Practically, client have
            40 seconds to reconnect. If your processing take a longer time,
            a new connection should be made. You should avoid doing long
            processing between each iteration or launch this processing into
            a background task.
        """
        while True:
            if self.should_stop:
                return

            try:
                response = await self.send(
                    {"channel": "/meta/connect", "connectionType": "long-polling"}
                )
            except asyncio.TimeoutError:
                logger.info("Timeout")
                continue
            except aiohttp.ClientResponseError as error:
                # ``status`` is the documented attribute; ``code`` is its
                # deprecated alias and is only read as a fallback for old
                # aiohttp releases that do not provide ``status``.
                status = getattr(error, "status", None)
                if status is None:
                    status = getattr(error, "code", None)
                if status == 408:  # Request timeout: simply reconnect
                    logger.info("Timeout")
                    continue
                raise

            # A stop may have been requested while waiting for the response
            if self.should_stop:
                return

            logger.debug("Messages: received %r", response)

            for message in response:
                if self.should_stop:
                    return
                yield message

    async def events(self) -> AsyncIterator[JSONObject]:
        """
        Asynchronous generator that fetch new events and return one as soon
        as one is available::

            async for message in client.events():
                channel = message['channel']
                print(channel, ':', message)

        This method is different from
        :py:func:`BaseSalesforceStreaming.messages` because it filter messages
        and provide only those related to the channels you subscribed.
        """
        async for message in self.messages():
            # Skip protocol messages, whose channel starts with "/meta/"
            if not message.get("channel", "").startswith("/meta/"):
                yield message

    async def ask_stop(self) -> None:
        """
        Ask client to stop receiving event::

            async for event in client.events():
                ...
                if ...:
                    await client.ask_stop()

        This call will eventually stop
        :py:func:`BaseSalesforceStreaming.messages` and
        :py:func:`BaseSalesforceStreaming.events` async generator but this can
        take some time if not called inside the loop body: the generator will
        wait a timeout response from Salesforce server.
        """
        self.should_stop = True

    async def unsubscribe(self, channel: str) -> JSONList:
        """
        Unsubscribe from a channel. Can be used directly::

            await client.unsubscribe('/topic/Foo')

        This method, and the underlying protocol, are safe to be started as a
        background task::

            loop.create_task(client.unsubscribe('/topic/Foo'))
        """
        response = await self.send(await self.get_unsubscribe_payload(channel))
        logger.info("Unsubscribe response: %r", response)
        return response

    async def stop(self) -> None:
        """
        Disconnect from Salesforce and close underlying connection.

        A best practice is to use async context manager interface that will
        call this method directly.

        The session is closed even if the disconnect request fails. When no
        session is open, no disconnect request is sent.
        """
        await self.ask_stop()
        try:
            # Without a session there is nothing to disconnect from
            if self.session is not None:
                await self.disconnect()
        finally:
            # Always release the connection, even when disconnect raised
            await self.close_session()

    # -------------------- Connection logic --------------------

    @property
    def token_url(self) -> str:
        """
        The url that should be used to fetch an access token.
        """
        url_prefix = "test" if self.sandbox else "login"
        return f"https://{url_prefix}.salesforce.com/services/oauth2/token"

    @abc.abstractmethod
    async def fetch_token(self) -> Tuple[str, str]:
        """
        Abstract coroutine method of connector that must provide an access
        token and the instance url linked.
        """

    async def create_connected_session(self) -> aiohttp.ClientSession:
        """
        This coroutine create an ``aiohttp.ClientSession`` using fetched token.

        The instance url returned with the token is stored on
        ``instance_url``. The returned session sends ``base_header`` plus a
        bearer ``Authorization`` header.
        """
        token, self.instance_url = await self.fetch_token()

        # Work on a copy so the class-level ``base_header`` is never mutated
        base_header = copy.deepcopy(self.base_header)
        base_header.update({"Authorization": f"Bearer {token}"})

        session = aiohttp.ClientSession(
            connector=self.connector, headers=base_header, loop=self.loop
        )
        return session

    async def close_session(self) -> None:
        """
        Close the underlying ``aiohttp.ClientSession`` connection.

        Does nothing when no session is open.
        """
        if self.session is None:
            return
        await self.session.close()
        self.session = None

    # -------------------- Bayeux/CometD logic layer --------------------

    @property
    def end_point(self) -> str:
        """
        Cometd endpoint
        """
        return f"/cometd/{self.version}/"

    async def get_handshake_payload(self) -> JSONObject:
        """
        Provide the handshake payload
        """
        return {
            "channel": "/meta/handshake",
            "supportedConnectionTypes": ["long-polling"],
            "version": "1.0",
            "minimumVersion": "1.0",
        }

    async def get_subscribe_payload(self, channel: str) -> JSONObject:
        """
        Provide the subscription payload for a specific channel
        """
        return {"channel": "/meta/subscribe", "subscription": channel}

    async def get_unsubscribe_payload(self, channel: str) -> JSONObject:
        """
        Provide the unsubscription payload for a specific channel
        """
        return {"channel": "/meta/unsubscribe", "subscription": channel}

    async def send(self, data: JSONObject) -> JSONType:
        """
        Send data to CometD server when the connection is established::

            # Manually disconnect
            await client.send({'channel': '/meta/disconnect'})

        The payload is copied before the message ``id`` (and ``clientId``,
        once known) are added, so the caller's dictionary is left untouched.
        """
        self.message_count += 1

        # Add id and client_id to payload
        data = copy.copy(data)
        data["id"] = str(self.message_count)
        if self.client_id:
            data["clientId"] = self.client_id

        # Post data
        return await self.post(self.end_point, json=data)

    async def handshake(self) -> JSONList:
        """
        Coroutine that perform an handshake (mandatory before any other
        action).

        Resets the message counter and stores the ``clientId`` found in the
        first element of the response on ``client_id``.
        """
        self.message_count = 0
        response = await self.send(await self.get_handshake_payload())
        logger.info("Handshake response: %r", response)
        self.client_id = response[0]["clientId"]
        return response

    async def disconnect(self) -> JSONList:
        """
        Disconnect from the SF streaming server
        """
        return await self.send({"channel": "/meta/disconnect"})

    # -------------------- IO layer helpers --------------------

    async def get(self, sub_url: str, **kwargs: Any) -> JSONType:
        """
        Perform a simple json get request from an internal url::

            response = await client.get('/myendpoint/')
        """
        return await self.request("get", sub_url, **kwargs)

    async def post(self, sub_url: str, **kwargs: Any) -> JSONType:
        """
        Perform a simple json post request from an internal url::

            response = await client.post('/myendpoint/', json={'data': 'foo'})
        """
        return await self.request("post", sub_url, **kwargs)

    async def request(self, method: str, sub_url: str, **kwargs: Any) -> JSONType:
        """
        Perform a simple json request from an internal url.

        ``sub_url`` is appended to ``instance_url``. Extra keyword arguments
        are forwarded to the session's ``request`` call. An error HTTP status
        raises through ``raise_for_status``; otherwise the decoded JSON body
        is returned.
        """
        url = f"{self.instance_url}{sub_url}"
        logger.debug("Perform %r to %r with %r", method, url, kwargs)

        async with self.session.request(
            method, url, timeout=self.timeout, **kwargs
        ) as resp:
            resp.raise_for_status()
            data = await resp.json()
        return data

    # -------------------- SPECIALS METHODS --------------------

    @property
    def loop(self) -> asyncio.AbstractEventLoop:
        """
        Running event loop
        """
        # Resolve lazily when no loop was given to the constructor
        if self._loop is None:
            self._loop = asyncio.get_event_loop()
        return self._loop

    # Asynchronous context manager

    async def __aenter__(self):
        await self.start()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.stop()