From 79a35982a8454bf6b6e11c76e1a351b585bb0087 Mon Sep 17 00:00:00 2001
From: Petr Kalashnikov <pka065@it6100016.klientdrift.uib.no>
Date: Mon, 15 Mar 2021 12:41:42 +0100
Subject: [PATCH] Fix signature for batch extraction method

---
 setra_client/client.py | 18 +++++++++++-------
 setra_client/models.py | 19 ++++++++++++++++++-
 2 files changed, 29 insertions(+), 8 deletions(-)

diff --git a/setra_client/client.py b/setra_client/client.py
index e8c2ac9..4f10c10 100644
--- a/setra_client/client.py
+++ b/setra_client/client.py
@@ -1,11 +1,14 @@
 """Client for connecting to SETRA API"""
 import logging
 import urllib.parse
-from typing import Union
+from typing import Union, Optional
 
 import requests
 
-from setra_client.models import Batch, CompleteBatch, BatchErrors
+from setra_client.models import (Batch,
+                                 CompleteBatch,
+                                 BatchErrors,
+                                 BatchProgressEnum)
 
 logger = logging.getLogger(__name__)
 
@@ -182,11 +185,12 @@ class SetraClient(object):
         return cls.from_dict(data)
 
     def get_batch(self,
-                  batch_id: int = None,
-                  min_created_date: str = None,
-                  max_created_date: str = None,
-                  batch_progress: str = None,
-                  interface: str = None):
+                  batch_id: Optional[int] = None,
+                  min_created_date: Optional[str] = None,
+                  max_created_date: Optional[str] = None,
+                  batch_progress: Optional[BatchProgressEnum] = None,
+                  interface: Optional[str] = None):
+
         """
         GETs one or all batches from SETRA.
         Dates (maximal and minimal creation dates) should 
diff --git a/setra_client/models.py b/setra_client/models.py
index 195cf25..54d9594 100644
--- a/setra_client/models.py
+++ b/setra_client/models.py
@@ -3,7 +3,7 @@ import datetime
 import json
 import typing
 from typing import Optional, TypeVar
-
+from enum import Enum
 import pydantic
 
 NameType = TypeVar('NameType')
@@ -30,6 +30,23 @@ class BaseModel(pydantic.BaseModel):
         return cls.from_dict(data)
 
 
+class BatchProgressEnum(Enum):
+  CREATED = 'created'
+  VALIDATION_COMPLETED = 'validation_completed'
+  VALIDATION_FAILED = 'validation_failed'
+  SENT_TO_UBW = 'sent_to_ubw'
+  SEND_TO_UBW_FAILED = 'send_to_ubw_failed'
+  POLLING_COMPLETED = 'polling_completed'
+  POLLING_FAILED = 'polling_failed'
+  UBW_IMPORT_OK = 'ubw_import_ok'
+  UBW_IMPORT_FAILED = 'ubw_import_failed'
+  FETCH_FINAL_VOUCHERNO_COMPLETED = 'fetch_final_voucherno_completed'
+  FETCH_FINAL_VOUCHERNO_FAILED = 'fetch_final_voucherno_failed'
+
+  def __str__(self):
+    return str(self.value)
+
+
 class Transaction(BaseModel):
     account: str
     amount: float
-- 
GitLab