Skip to content
Snippets Groups Projects
client.py 7.54 KiB
"""Client for connecting to SETRA API"""
import logging
import urllib.parse
from typing import Union

import requests

from setra_client.models import Batch, CompleteBatch, BatchErrors

logger = logging.getLogger(__name__)


def merge_dicts(*dicts):
    """
    Return a new dict holding the union of all given dicts.

    Later dicts win on key collisions; falsy arguments (e.g. ``None`` or an
    empty dict) are skipped.  None of the inputs are mutated.

    >>> merge_dicts({'a': 1}, {'b': 2})
    {'a': 1, 'b': 2}
    >>> merge_dicts({'a': 1}, {'a': 2})
    {'a': 2}
    >>> merge_dicts(None, None, None)
    {}
    """
    return {key: source[key]
            for source in dicts if source
            for key in source}


class SetraEndpoints:
    """ Get endpoints relative to the SETRA API URL. """

    def __init__(self,
                 url,
                 batch_url='api/batch/',
                 transaction_url='api/transaction/',
                 voucher_url='api/voucher/',
                 new_batch_url='api/addtrans/',
                 batch_complete_url='api/batch_complete/',
                 batch_error_url="api/batch_error/"
                 ):
        """
        :param str url: Base API URL. All endpoint paths are resolved against
            it with ``urllib.parse.urljoin``, so it should normally end with
            '/' (otherwise urljoin replaces its last path segment).
        :param str batch_url: Relative path of the batch endpoint
        :param str transaction_url: Relative path of the transaction endpoint
        :param str voucher_url: Relative path of the voucher endpoint
        :param str new_batch_url: Relative path used to POST a new batch
        :param str batch_complete_url: Relative path of the batch_complete
            endpoint
        :param str batch_error_url: Relative path of the batch_error endpoint
        """
        self.baseurl = url
        self.batch_url = batch_url
        self.transaction_url = transaction_url
        self.voucher_url = voucher_url
        self.new_batch_url = new_batch_url
        self.batch_complete_url = batch_complete_url
        self.batch_error_url = batch_error_url

    def __repr__(self):
        return '{cls.__name__}({url!r})'.format(
            cls=type(self),
            url=self.baseurl)

    def _collection(self, path):
        """Resolve a relative endpoint *path* against the base URL."""
        return urllib.parse.urljoin(self.baseurl, path)

    def _item(self, path, item_id):
        """Resolve ``<path>/<item_id>`` against the base URL."""
        # The default paths end with '/'; strip it so joining an id yields
        # 'api/batch/5' rather than 'api/batch//5'.
        return urllib.parse.urljoin(self.baseurl,
                                    '/'.join((path.rstrip('/'), item_id)))

    def batch(self, batch_id: str = None):
        """
        URL for Batch endpoint (a single batch when *batch_id* is given)
        """
        if batch_id is None:
            return self._collection(self.batch_url)
        return self._item(self.batch_url, batch_id)

    def transaction(self, trans_id: str = None):
        """
        Url for Transaction endpoint (a single one when *trans_id* is given)
        """
        if trans_id is None:
            return self._collection(self.transaction_url)
        return self._item(self.transaction_url, trans_id)

    def voucher(self, vouch_id: str = None):
        """
        Url for Voucher endpoint (a single one when *vouch_id* is given)
        """
        if vouch_id is None:
            return self._collection(self.voucher_url)
        return self._item(self.voucher_url, vouch_id)

    def post_new_batch(self):
        """
        URL used to POST a new batch
        """
        return self._collection(self.new_batch_url)

    def batch_complete(self, batch_id: str):
        """
        URL for batch_complete endpoint of the given batch
        """
        return self._item(self.batch_complete_url, batch_id)

    def batch_error(self, batch_id: str):
        """
        URL for batch_error endpoint of the given batch
        """
        return self._item(self.batch_error_url, batch_id)


class SetraClient(object):
    """
    Client for the SETRA API.

    Issues HTTP requests through a requests session (or the requests module
    itself) and builds endpoint URLs with SetraEndpoints.
    """
    default_headers = {
        'Accept': 'application/json',
    }

    def __init__(self,
                 url: str,
                 headers: Union[None, dict] = None,
                 return_objects: bool = True,
                 use_sessions: bool = True,
                 ):
        """
        SETRA API client.

        :param str url: Base API URL
        :param dict headers: Append extra headers to all requests
        :param bool return_objects: Return objects instead of raw JSON
            (used by get_batch_complete and get_batch_errors)
        :param bool use_sessions: Keep HTTP connections alive (default True)
        """

        self.urls = SetraEndpoints(url)
        self.headers = merge_dicts(self.default_headers, headers)
        self.return_objects = return_objects
        if use_sessions:
            self.session = requests.Session()
        else:
            # The requests module offers the same request() call as a
            # Session, so it stands in when connection reuse is disabled.
            self.session = requests

    def _build_request_headers(self, headers):
        """Return the client-wide headers overlaid with per-request *headers*."""
        return merge_dicts(self.headers, headers)

    def call(self,
             method_name,
             url,
             headers=None,
             params=None,
             return_response=True,
             **kwargs):
        """
        Perform an HTTP request.

        :param str method_name: HTTP method, e.g. 'GET'
        :param str url: Absolute URL to call
        :param dict headers: Extra headers for this request only
        :param dict params: Query parameters
        :param bool return_response: Return the raw Response (default).
            When False, raise for HTTP error statuses and return the decoded
            JSON body instead.
        :param kwargs: Passed on to ``session.request``
        """
        headers = self._build_request_headers(headers)
        if params is None:
            params = {}
        logger.debug('Calling %s %s with params=%r',
                     method_name,
                     urllib.parse.urlparse(url).path,
                     params)
        r = self.session.request(method_name,
                                 url,
                                 headers=headers,
                                 params=params,
                                 **kwargs)
        if r.status_code in (500, 400, 401):
            logger.warning('Got HTTP %d: %r', r.status_code, r.content)
        if return_response:
            return r
        r.raise_for_status()
        # NOTE: this raises if the response body is text rather than JSON
        # (only reachable when return_response is False).
        return r.json()

    def get(self, url, **kwargs):
        """GET *url*; see call() for keyword arguments."""
        return self.call('GET', url, **kwargs)

    def put(self, url, **kwargs):
        """PUT *url*; see call() for keyword arguments."""
        return self.call('PUT', url, **kwargs)

    def post(self, url, **kwargs):
        """POST *url*; see call() for keyword arguments."""
        return self.call('POST', url, **kwargs)

    def object_or_data(self, cls, data) -> Union[object, dict]:
        """Return ``cls.from_dict(data)`` if return_objects is set, else *data*."""
        if not self.return_objects:
            return data
        return cls.from_dict(data)

    def _deserialize(self, cls, response):
        """Decode *response* into *cls* (or raw JSON, per return_objects)."""
        if self.return_objects:
            # An HTTP error body can never be valid model data; surface the
            # HTTP error instead of a confusing from_dict failure.
            response.raise_for_status()
        return self.object_or_data(cls, response.json())

    def get_batch(self, batch_id: Union[int, str, None] = None):
        """
        GETs one or all batches from SETRA.

        Returns the decoded JSON body; the HTTP status is not checked.
        """
        if batch_id is not None:
            batch_id = str(batch_id)

        url = self.urls.batch(batch_id)

        response = self.get(url)
        return response.json()

    def get_voucher(self, vouch_id: Union[int, str, None] = None):
        """
        GETs one or all vouchers from SETRA.

        Returns the decoded JSON body; the HTTP status is not checked.
        """
        if vouch_id is not None:
            vouch_id = str(vouch_id)

        url = self.urls.voucher(vouch_id)
        response = self.get(url)
        return response.json()

    def get_transaction(self, trans_id: Union[int, str, None] = None):
        """
        GETs one or all transactions from SETRA.

        Returns the decoded JSON body; the HTTP status is not checked.
        """
        if trans_id is not None:
            trans_id = str(trans_id)

        url = self.urls.transaction(trans_id)
        response = self.get(url)
        return response.json()

    def post_new_batch(self, batchdata: Batch, return_response: bool = False):
        """
        POST combination of batch, vouchers and transactions.

        With return_response=False (default) the HTTP status is checked and
        the decoded JSON body is returned; otherwise the raw Response is
        returned.
        """
        url = self.urls.post_new_batch()
        headers = {'Content-Type': 'application/json'}
        # assumes Batch.json() returns a JSON string (pydantic-style) - confirm
        response = self.post(url,
                             data=batchdata.json(),
                             headers=headers,
                             return_response=return_response)
        return response

    def get_batch_complete(self, batch_id: Union[int, str]):
        """
        GETs complete batch (with vouchers and transactions)
        from SETRA.

        Accepts an int or str id, like get_batch().
        """
        url = self.urls.batch_complete(str(batch_id))
        response = self.get(url)
        return self._deserialize(CompleteBatch, response)

    def get_batch_errors(self, batch_id: Union[int, str]):
        """
        GETs the errors registered for a batch from SETRA.

        Accepts an int or str id, like get_batch().
        """
        url = self.urls.batch_error(str(batch_id))
        response = self.get(url)
        return self._deserialize(BatchErrors, response)


def get_client(config_dict):
    """
    Build a SetraClient from a configuration mapping.

    Every key in *config_dict* is forwarded as a keyword argument to the
    SetraClient constructor.
    """
    client = SetraClient(**config_dict)
    return client