From 8e349a95f30977a213f5fe0b6a5072c08683e844 Mon Sep 17 00:00:00 2001 From: hhn <hhn@uio.no> Date: Mon, 24 Apr 2023 13:46:54 +0200 Subject: [PATCH] Add running mypy to pipeline --- .gitlab-ci.yml | 12 -- mypy.ini | 21 ++++ requirements-test.txt | 3 + setra_client/client.py | 241 +++++++++++++++++++++++++--------------- setra_client/models.py | 13 ++- setra_client/version.py | 2 +- setup.py | 3 +- tests/test_client.py | 9 +- tests/test_models.py | 3 + tox.ini | 3 +- 10 files changed, 198 insertions(+), 112 deletions(-) create mode 100644 mypy.ini diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 2df9383..b2cf261 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -22,18 +22,6 @@ json: stage: test script: tox -e json -python36: - image: python:3.6 - stage: test - script: - - tox -e py36 - -python37: - image: python:3.7 - stage: test - script: - - tox -e py37 - python38: image: python:3.8 stage: test diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..5248c3a --- /dev/null +++ b/mypy.ini @@ -0,0 +1,21 @@ +[mypy] +plugins = pydantic.mypy +warn_redundant_casts = true +warn_unused_ignores = true +warn_unused_configs = true +show_traceback = true +show_error_codes = true +strict_optional = true +strict_equality = true +no_implicit_optional = true +check_untyped_defs = true +disable_error_code = union-attr + +[mypy-setra_client.*] +disallow_untyped_calls = true +disallow_untyped_defs = true +disallow_incomplete_defs = true +warn_no_return = true +disallow_any_generics = true +warn_return_any = true +disallow_any_unimported = true \ No newline at end of file diff --git a/requirements-test.txt b/requirements-test.txt index b74d7b0..9506b76 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,3 +1,6 @@ pytest requests_mock tox +mypy +types-setuptools +types-requests diff --git a/setra_client/client.py b/setra_client/client.py index e8fbc99..87c0664 100644 --- a/setra_client/client.py +++ b/setra_client/client.py @@ -2,7 +2,8 @@ import logging import urllib.parse from datetime import date -from typing import Union, Optional, List +from types import ModuleType +from typing import Any, Dict, List, Literal, Optional, overload, Tuple, Type, Union import requests @@ -19,12 +20,15 @@ from setra_client.models import ( AbwOrderErrors, ResponseStatusEnum, Voucher, + T, ) +JsonType = Any + logger = logging.getLogger(__name__) -def merge_dicts(*dicts): +def merge_dicts(*dicts: Optional[Dict[Any, Any]]) -> Dict[Any, Any]: """ Combine a series of dicts without mutating any of them. @@ -51,24 +55,24 @@ class IncorrectPathError(Exception): class SetraEndpoints: def __init__( self, - url, - batch_url="api/batch/", - transaction_url="api/transaction/", - voucher_url="api/voucher/", - new_batch_url="api/addtrans/", - put_batch_url="api/addtrans/", - batch_complete_url="api/batch_complete/", - batch_error_url="api/batch_error/", - parameters_url="api/parameters/", + url: str, + batch_url: str = "api/batch/", + transaction_url: str = "api/transaction/", + voucher_url: str = "api/voucher/", + new_batch_url: str = "api/addtrans/", + put_batch_url: str = "api/addtrans/", + batch_complete_url: str = "api/batch_complete/", + batch_error_url: str = "api/batch_error/", + parameters_url: str = "api/parameters/", # Order urls (sotra): - order_url="api/order/", - order_complete_url="api/order_complete/", - detail_url="api/detail/", - details_in_order_url="api/details_in_order/", - abw_order_complete_url="api/abw_order_complete/", - post_add_abw_order_url="api/add_abw_order/", - abw_order_errors_url="api/abw_order_errors/", - ): + order_url: str = "api/order/", + order_complete_url: str = "api/order_complete/", + detail_url: str = "api/detail/", + details_in_order_url: str = "api/details_in_order/", + abw_order_complete_url: str = "api/abw_order_complete/", + post_add_abw_order_url: str = "api/add_abw_order/", + abw_order_errors_url: str = "api/abw_order_errors/", + ) -> None: self.baseurl = url self.batch_url = batch_url self.transaction_url = transaction_url @@ -89,10 +93,10 @@ class SetraEndpoints: """ Get endpoints relative to the SETRA API URL. """ - def __repr__(self): + def __repr__(self) -> str: return "{cls.__name__}({url!r})".format(cls=type(self), url=self.baseurl) - def batch(self, batch_id: str = None): + def batch(self, batch_id: Optional[str] = None) -> str: """ URL for Batch endpoint """ @@ -103,7 +107,7 @@ class SetraEndpoints: self.baseurl, "/".join((self.batch_url, batch_id)) ) - def transaction(self, trans_id: str = None): + def transaction(self, trans_id: Optional[str] = None) -> str: """ Url for Transaction endpoint """ @@ -114,7 +118,7 @@ class SetraEndpoints: self.baseurl, "/".join((self.transaction_url, trans_id)) ) - def voucher(self, vouch_id: str = None): + def voucher(self, vouch_id: Optional[str] = None) -> str: """ Url for Voucher endpoint """ @@ -125,13 +129,13 @@ class SetraEndpoints: self.baseurl, "/".join((self.voucher_url, vouch_id)) ) - def post_new_batch(self): + def post_new_batch(self) -> str: return urllib.parse.urljoin(self.baseurl, self.new_batch_url) - def put_update_batch(self): + def put_update_batch(self) -> str: return urllib.parse.urljoin(self.baseurl, self.new_batch_url) - def batch_complete(self, batch_id: str): + def batch_complete(self, batch_id: str) -> str: """ URL for Batch endpoint """ @@ -139,7 +143,7 @@ class SetraEndpoints: self.baseurl, "/".join((self.batch_complete_url, batch_id)) ) - def batch_error(self, batch_id: str): + def batch_error(self, batch_id: str) -> str: """ URL for batch_error endpoint """ @@ -147,25 +151,27 @@ class SetraEndpoints: self.baseurl, "/".join((self.batch_error_url, batch_id)) ) - def parameters(self): + def parameters(self) -> str: """Get url for parameters endpoint""" return urllib.parse.urljoin(self.baseurl, self.parameters_url) - def abw_order_complete(self, abw_order_id: str = None): + def abw_order_complete(self, abw_order_id: Optional[str] = None) -> str: """ URL for getting an abw order including a list of its orders, and each order contains a list of its detail objects """ + if abw_order_id is None: + raise ValueError(f"Illegal value for abw_order_id: {abw_order_id}") return urllib.parse.urljoin( self.baseurl, "/".join((self.abw_order_complete_url, abw_order_id)) ) - def add_abw_order(self): + def add_abw_order(self) -> str: """ URL for posting a new complete abw order containing order and detail objects """ return urllib.parse.urljoin(self.baseurl, self.post_add_abw_order_url) - def order(self, order_id: str = None): + def order(self, order_id: Optional[str] = None) -> str: """ URL for order endpoint """ @@ -176,7 +182,7 @@ class SetraEndpoints: self.baseurl, "/".join((self.order_url, order_id)) ) - def order_complete(self, order_id: str): + def order_complete(self, order_id: str) -> str: """ URL for getting an order object including a list of its detail objects """ @@ -184,7 +190,7 @@ class SetraEndpoints: self.baseurl, "/".join((self.order_complete_url, order_id)) ) - def detail(self, detail_id: str): + def detail(self, detail_id: str) -> str: """ URL for detail endpoint """ @@ -192,7 +198,7 @@ class SetraEndpoints: self.baseurl, "/".join((self.detail_url, detail_id)) ) - def details_in_order(self, order_id: str): + def details_in_order(self, order_id: str) -> str: """ URL for getting a list of detail objects in an order """ @@ -200,7 +206,7 @@ class SetraEndpoints: self.baseurl, "/".join((self.details_in_order_url, order_id)) ) - def abw_order_errors(self, abw_order_id: str): + def abw_order_errors(self, abw_order_id: str) -> str: """ URL for getting an object containing lists of all errors for AbwOrder, Orders and Details """ @@ -217,7 +223,7 @@ class SetraClient(object): def __init__( self, url: str, - headers: Union[None, dict] = None, + headers: Optional[Dict[str, str]] = None, return_objects: bool = True, use_sessions: bool = True, ): @@ -233,28 +239,56 @@ class SetraClient(object): self.urls = SetraEndpoints(url) self.headers = merge_dicts(self.default_headers, headers) self.return_objects = return_objects + self.session: Union[ModuleType, requests.Session] if use_sessions: self.session = requests.Session() else: self.session = requests - def _build_request_headers(self, headers): - request_headers = {} + def _build_request_headers( + self, headers: Optional[Dict[str, Any]] + ) -> Dict[str, Any]: + request_headers: Dict[str, Any] = dict() for h in self.headers: request_headers[h] = self.headers[h] - for h in headers or (): - request_headers[h] = headers[h] + if headers is not None: + for h in headers: + request_headers[h] = headers[h] return request_headers + @overload def call( self, - method_name, - url, - headers=None, - params=None, - return_response=True, - **kwargs, - ): + method_name: Literal["GET"], + url: str, + headers: Optional[Dict[str, Any]] = None, + params: Optional[Dict[str, Any]] = None, + return_response: Optional[bool] = True, + **kwargs: Any, + ) -> Union[JsonType, str, None]: + ... + + @overload + def call( + self, + method_name: str, + url: str, + headers: Optional[Dict[str, Any]] = None, + params: Optional[Dict[str, Any]] = None, + return_response: Optional[bool] = True, + **kwargs: Any, + ) -> Union[JsonType, str, requests.models.Response, None]: + ... + + def call( + self, + method_name: str, + url: str, + headers: Optional[Dict[str, Any]] = None, + params: Optional[Dict[str, Any]] = None, + return_response: Optional[bool] = True, + **kwargs: Any, + ) -> Union[JsonType, str, requests.models.Response, None]: if method_name == "GET": return_response = False headers = self._build_request_headers(headers) @@ -296,16 +330,22 @@ class SetraClient(object): return_data = r.text return return_data - def get(self, url, **kwargs): + def get(self, url: str, **kwargs: Any) -> Union[JsonType, str, None]: return self.call("GET", url, **kwargs) - def put(self, url, **kwargs): + def put( + self, url: str, **kwargs: Any + ) -> Union[JsonType, str, requests.models.Response, None]: return self.call("PUT", url, **kwargs) - def post(self, url, **kwargs): + def post( + self, url: str, **kwargs: Any + ) -> Union[JsonType, str, requests.models.Response, None]: return self.call("POST", url, **kwargs) - def object_or_data(self, cls, data: Union[dict, List[dict]]) -> Union[object, dict]: + def object_or_data( + self, cls: Type[T], data: Union[Dict[Any, Any], List[Dict[Any, Any]]] + ) -> Union[Dict[Any, Any], List[Dict[Any, Any]], T, List[T]]: """Create list of objects or return data as is""" if not self.return_objects: return data @@ -321,7 +361,7 @@ class SetraClient(object): max_created_date: Optional[date] = None, batch_progress: Optional[BatchProgressEnum] = None, interface: Optional[str] = None, - ) -> Union[List[dict], dict, List[OutputBatch]]: + ) -> Union[JsonType, str, List[OutputBatch], None]: """ Search batches in SETRA. Dates (maximal and minimal creation dates) should @@ -340,27 +380,29 @@ class SetraClient(object): url = self.urls.batch() data = self.get(url, params=params) - if self.return_objects: - if isinstance(data, list): - return [OutputBatch(**item) for item in data] - elif isinstance(data, dict): - return [OutputBatch(**data)] - else: + if not self.return_objects: return data + if isinstance(data, list): + return [OutputBatch(**item) for item in data] + elif isinstance(data, dict): + return [OutputBatch(**data)] + return None - def get_batch(self, batch_id: str) -> Union[OutputBatch, dict]: + def get_batch(self, batch_id: str) -> Union[JsonType, str, OutputBatch, None]: """ GETs a batch from SETRA. """ url = self.urls.batch(str(batch_id)) - data = self.get(url) + data: Any = self.get(url) if data and self.return_objects: return OutputBatch(**data) else: return data - def get_voucher(self, vouch_id: int = None): + def get_voucher( + self, vouch_id: Union[str, int, None] = None + ) -> Union[JsonType, str, None]: """ GETs one or all vouchers from SETRA """ @@ -371,7 +413,9 @@ class SetraClient(object): data = self.get(url) return data - def get_transaction(self, trans_id: int = None): + def get_transaction( + self, trans_id: Union[str, int, None] = None + ) -> Union[JsonType, str, None]: """ GETs one or all transactions from SETRA """ @@ -382,7 +426,9 @@ class SetraClient(object): data = self.get(url) return data - def post_new_batch(self, batchdata: InputBatch): + def post_new_batch( + self, batchdata: InputBatch + ) -> Tuple[str, Dict[str, Union[int, JsonType, bytes, None]]]: """ POST combination of batch, vouchers and transactions """ @@ -409,7 +455,9 @@ class SetraClient(object): "content": content, } - def put_update_batch(self, batchdata: InputBatch): + def put_update_batch( + self, batchdata: InputBatch + ) -> Tuple[str, Dict[str, Union[int, JsonType, bytes, None]]]: """ PUT updates an existing batch with vouchers and transactions, if the batch exists in setra, and has status=created, or validation failed. @@ -440,62 +488,74 @@ class SetraClient(object): "content": content, } - def get_batch_complete(self, batch_id: str): + def get_batch_complete( + self, batch_id: str + ) -> Union[ + CompleteBatch, List[CompleteBatch], Dict[Any, Any], List[Dict[Any, Any]] + ]: """ GETs complete batch (with vouchers and transactions) from SETRA """ url = self.urls.batch_complete(batch_id) - data = self.get(url) + data: Any = self.get(url) return self.object_or_data(CompleteBatch, data) - def get_batch_errors(self, batch_id: str): + def get_batch_errors( + self, batch_id: str + ) -> Union[BatchErrors, List[BatchErrors], Dict[Any, Any], List[Dict[Any, Any]]]: url = self.urls.batch_error(batch_id) - data = self.get(url) + data: Any = self.get(url) return self.object_or_data(BatchErrors, data) - def get_parameters(self, interface: str = None): + def get_parameters( + self, interface: Optional[str] = None + ) -> Union[Parameter, List[Parameter], Dict[Any, Any], List[Dict[Any, Any]]]: """Make a GET request to the parameters endpoint""" url = self.urls.parameters() queryparams = None if interface: queryparams = {"interface": interface} - data = self.get(url, params=queryparams) + data: Any = self.get(url, params=queryparams) return self.object_or_data(Parameter, data) # Order ("Sotra") functions: - def get_order(self, order_id: str) -> Union[Order, dict]: + def get_order(self, order_id: str) -> Union[Order, JsonType, str, None]: """ GETs one order object """ url = self.urls.order(str(order_id)) - data = self.get(url) + data: Any = self.get(url) if self.return_objects: return Order(**data) else: return data - def get_detail(self, detail_id: str) -> Union[Detail, dict]: + def get_detail(self, detail_id: str) -> Union[Detail, JsonType, str, None]: """ GETs one detail object """ url = self.urls.detail(str(detail_id)) - data = self.get(url) + data: Any = self.get(url) if self.return_objects: return Detail(**data) else: return data - def get_order_list(self): + def get_order_list( + self, + ) -> Union[Order, List[Order], Dict[Any, Any], List[Dict[Any, Any]]]: """ GETs a list of all orders, without detail objects """ url = self.urls.order() - data = self.get(url) + data: Any = self.get(url) return self.object_or_data(Order, data) - def get_details_in_order(self, order_id: str) -> Union[List[Detail], dict]: + def get_details_in_order( + self, order_id: str + ) -> Union[List[Detail], Any, str, None]: """ GETs list of all detail objects belonging to an order """ @@ -506,32 +566,35 @@ class SetraClient(object): return [Detail(**item) for item in data] elif isinstance(data, dict): return [Detail(**data)] - else: - return data + return data - def get_order_complete(self, order_id: str) -> Union[Order, dict]: + def get_order_complete(self, order_id: str) -> Union[Order, JsonType, str, None]: """ GETs one order, with all detail objects """ url = self.urls.order_complete(str(order_id)) - data = self.get(url) + data: Any = self.get(url) if self.return_objects: return Order(**data) else: return data - def get_abw_order_complete(self, abw_order_id: str) -> Union[AbwOrder, dict]: + def get_abw_order_complete( + self, abw_order_id: str + ) -> Union[AbwOrder, JsonType, str, None]: """ GETs one abworder, with all order and detail objects """ url = self.urls.abw_order_complete(str(abw_order_id)) - data = self.get(url) + data: Any = self.get(url) if self.return_objects: return AbwOrder(**data) else: return data - def post_add_abw_order(self, abworder: AbwOrder): + def post_add_abw_order( + self, abworder: AbwOrder + ) -> Tuple[str, Dict[str, Union[int, JsonType, bytes]]]: """ POST one AbwOrder, with its orders and its details. @@ -567,16 +630,20 @@ class SetraClient(object): "content": content, } - def get_abw_order_errors(self, abw_order_id: str): + def get_abw_order_errors( + self, abw_order_id: str + ) -> Union[ + AbwOrderErrors, List[AbwOrderErrors], Dict[Any, Any], List[Dict[Any, Any]] + ]: """ Gets an object containing three lists of all errors, for each of AbwOrder, Orders and Details """ url = self.urls.abw_order_errors(abw_order_id) - response = self.get(url) + response: Any = self.get(url) return self.object_or_data(AbwOrderErrors, response.json()) -def get_client(config_dict): +def get_client(config_dict: Dict[str, Any]) -> SetraClient: """ Get a SetraClient from configuration. """ diff --git a/setra_client/models.py b/setra_client/models.py index 0fa2645..0c81e0b 100644 --- a/setra_client/models.py +++ b/setra_client/models.py @@ -2,7 +2,7 @@ import datetime import json from enum import Enum -from typing import Optional, List +from typing import Any, Dict, List, Optional, Type, TypeVar import pydantic @@ -12,16 +12,19 @@ def to_lower_camel(s: str) -> str: return "".join([first.lower(), *map(str.capitalize, others)]) +T = TypeVar("T", bound="BaseModel") + + class BaseModel(pydantic.BaseModel): """Expanded BaseModel for convenience""" @classmethod - def from_dict(cls, data: dict): + def from_dict(cls: Type[T], data: Dict[Any, Any]) -> T: """Initialize class from dict""" return cls(**data) @classmethod - def from_json(cls, json_data: str): + def from_json(cls: Type[T], json_data: str) -> T: """Initialize class from json file""" data = json.loads(json_data) return cls.from_dict(data) @@ -40,7 +43,7 @@ class BatchProgressEnum(str, Enum): FETCH_FINAL_VOUCHERNO_COMPLETED = "fetch_final_voucherno_completed" FETCH_FINAL_VOUCHERNO_FAILED = "fetch_final_voucherno_failed" - def __str__(self): + def __str__(self) -> str: return str(self.value) @@ -49,7 +52,7 @@ class ResponseStatusEnum(str, Enum): CONFLICT = "Conflict" UNKNOWN = "Unknown" - def __str__(self): + def __str__(self) -> str: return str(self.value) diff --git a/setra_client/version.py b/setra_client/version.py index 6071371..bd43526 100644 --- a/setra_client/version.py +++ b/setra_client/version.py @@ -7,7 +7,7 @@ import pkg_resources DISTRIBUTION_NAME = "setra-client" -def get_distribution(): +def get_distribution() -> pkg_resources.Distribution: """Get the distribution object for this single module dist.""" try: return pkg_resources.get_distribution(DISTRIBUTION_NAME) diff --git a/setup.py b/setup.py index 665399c..9246fc2 100644 --- a/setup.py +++ b/setup.py @@ -70,6 +70,7 @@ def run_setup(): author_email="bnt-int@usit.uio.no", use_scm_version=True, packages=get_packages(), + package_data={"setra_client": ["py.typed"]}, setup_requires=setup_requirements, install_requires=install_requirements, tests_require=test_requirements, @@ -81,8 +82,6 @@ def run_setup(): "Intended Audience :: Developers", "Topic :: Software Development :: Libraries", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", ], keywords="SETRA API client", diff --git a/tests/test_client.py b/tests/test_client.py index d6132ac..a7205f2 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -5,6 +5,7 @@ from json.decoder import JSONDecodeError import pytest import requests from requests import HTTPError +from typing import Any, Dict from setra_client.client import SetraClient from setra_client.client import SetraEndpoints @@ -37,14 +38,14 @@ def client_cls(header_name): def test_init_does_not_mutate_arg(client_cls, baseurl): - headers = {} + headers: Dict[str, Any] = {} client = client_cls(baseurl, headers=headers) assert headers is not client.headers assert not headers def test_init_applies_default_headers(client_cls, baseurl, header_name): - headers = {} + headers: Dict[str, Any] = {} client = client_cls(baseurl, headers=headers) assert header_name in client.headers assert client.headers[header_name] == client.default_headers[header_name] @@ -587,5 +588,5 @@ def test_send_in_abworder_failure( response = client.post_add_abw_order(abworder) assert isinstance(response, tuple) - assert (response[0], "Unknown") - assert (response[1]["code"], 404) + assert response[0] == "Unknown" + assert response[1]["code"] == 404 diff --git a/tests/test_models.py b/tests/test_models.py index f46cfb8..89884cf 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -106,6 +106,7 @@ def test_detail_list(detail_list_fixture): def test_order_with_detail(order_with_detail_fixture): order = Order(**order_with_detail_fixture) + assert order.details is not None assert len(order.details) == 2 assert order.details[0].koststed == "123123" assert order.details[1].koststed == "444" @@ -114,6 +115,8 @@ def test_order_with_detail(order_with_detail_fixture): def test_complete_abw_order(complete_abw_order_fixture): abworder = AbwOrder(**complete_abw_order_fixture) assert len(abworder.orders) == 2 + assert abworder.orders[0].details is not None + assert abworder.orders[1].details is not None assert len(abworder.orders[0].details) == 1 assert len(abworder.orders[1].details) == 1 assert abworder.orders[0].id == 3 diff --git a/tox.ini b/tox.ini index 7686b09..2bbe548 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py36,py37,py38 +envlist = py38 [testenv] description = Run tests with {basepython} @@ -8,6 +8,7 @@ deps = -rrequirements.txt commands = {envpython} -m pytest --junitxml=junit-{envname}.xml {posargs} + {envpython} -m mypy --config-file mypy.ini setra_client tests [pytest] xfail_strict = true -- GitLab