import dataclasses
import threading
import time
import uuid
from typing import Any, Iterable, List, Optional, Tuple, Union

from confluent_kafka import OFFSET_INVALID

from quixstreams.sinks import BatchingSink
from quixstreams.sinks.base import SinkBatch
from quixstreams.sinks.base.item import SinkItem
from quixstreams.sources import Source, StatefulSource

DEFAULT_TIMEOUT = 10.0

class Timeout:
    """
    Utility class to create time-limited `while` loops.

    It tracks the time elapsed since its creation and raises `TimeoutError`
    on the first `bool(Timeout)` check after the deadline passes.

    Use it when testing `while` loops to make sure they exit at some point.
    """

    def __init__(self, seconds: float = DEFAULT_TIMEOUT):
        self._end = time.monotonic() + seconds

    def __bool__(self):
        expired = time.monotonic() >= self._end
        if expired:
            raise TimeoutError("Timeout expired")
        return True
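
# Illustrative usage (an assumption about intended use, not part of the original
# file): bound a polling loop in a test so it cannot spin forever. The predicate
# `condition` is a hypothetical stand-in for whatever the test waits on.
def _example_timeout_loop(condition=lambda: True) -> None:
    timeout = Timeout(seconds=1.0)
    while timeout:  # raises TimeoutError once one second has elapsed
        if condition():
            break
        time.sleep(0.01)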

@dataclasses.dataclass
class TopicPartitionStub:
    topic: str
    partition: int
    offset: int = OFFSET_INVALID

class ConfluentKafkaMessageStub:
    """
    A stub object to mock `confluent_kafka.Message`.

    Instances of `confluent_kafka.Message` cannot be created directly from Python;
    see https://github.com/confluentinc/confluent-kafka-python/issues/1535.
    """

    def __init__(
        self,
        topic: str = "test",
        partition: int = 0,
        offset: int = 0,
        timestamp: Tuple[int, int] = (1, 123),
        key: Optional[bytes] = None,
        value: Optional[bytes] = None,
        headers: Optional[List[Tuple[str, bytes]]] = None,
        latency: Optional[float] = None,
        leader_epoch: Optional[int] = None,
    ):
        self._topic = topic
        self._partition = partition
        self._offset = offset
        self._timestamp = timestamp
        self._key = key
        self._value = value
        self._headers = headers
        self._latency = latency
        self._leader_epoch = leader_epoch

    def headers(self, *args, **kwargs) -> Optional[List[Tuple[str, bytes]]]:
        return self._headers

    def key(self, *args, **kwargs) -> Optional[Union[str, bytes]]:
        return self._key

    def offset(self, *args, **kwargs) -> int:
        return self._offset

    def partition(self, *args, **kwargs) -> int:
        return self._partition

    def timestamp(self, *args, **kwargs) -> Tuple[int, int]:
        return self._timestamp

    def topic(self, *args, **kwargs) -> str:
        return self._topic

    def value(self, *args, **kwargs) -> Optional[Union[str, bytes]]:
        return self._value

    def latency(self, *args, **kwargs) -> Optional[float]:
        return self._latency

    def leader_epoch(self, *args, **kwargs) -> Optional[int]:
        return self._leader_epoch

    def error(self) -> None:
        return None

    def __len__(self) -> int:
        # Length of the value payload; guard against a `None` value
        return len(self._value or b"")
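
# Illustrative usage (assumption, not from the original file): build a fake
# Kafka message for a unit test without a real broker.
def _example_message_stub() -> ConfluentKafkaMessageStub:
    return ConfluentKafkaMessageStub(
        topic="test-topic",
        partition=0,
        offset=42,
        key=b"key",
        value=b'{"field": 1}',
        headers=[("trace-id", b"abc123")],
    )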

class DummySink(BatchingSink):
    """A sink that collects every written item in memory for later assertions."""

    def __init__(self):
        super().__init__()
        self._results = []

    def write(self, batch: SinkBatch):
        for item in batch:
            self._results.append(item)

    @property
    def results(self) -> List[SinkItem]:
        return self._results

    @property
    def total_batched(self) -> int:
        return sum(batch.size for batch in self._batches.values())
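
# Illustrative usage sketch (assumption): attach the sink to a streaming
# dataframe in a test, run the app, then assert on what was written. `app`,
# `sdf`, and `expected_values` are hypothetical objects from the wider test.
#
#     sink = DummySink()
#     sdf.sink(sink)
#     app.run()  # stopped by the test, e.g. via a message count or timeout
#     assert [item.value for item in sink.results] == expected_values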

class DummySource(Source):
    """A source that produces a fixed list of values and can raise errors on
    demand in `run`, `stop`, or `cleanup` to exercise failure handling."""

    def __init__(
        self,
        name: Optional[str] = None,
        values: Optional[List[Any]] = None,
        finished: Optional[threading.Event] = None,
        error_in: Optional[Union[str, List[str]]] = None,
        pickeable_error: bool = True,
    ) -> None:
        super().__init__(name or str(uuid.uuid4()), 10)
        self.key = "dummy"
        self.values = values or []
        self.finished = finished
        self.error_in = error_in or []
        self.pickeable_error = pickeable_error

    def run(self):
        self._produce()
        if "run" in self.error_in:
            self.error("test run error")
        if self.finished:
            self.finished.set()
        # Keep the source alive until it is stopped from the outside
        while self.running:
            time.sleep(0.1)

    def _produce(self):
        for value in self.values:
            msg = self.serialize(key=self.key, value=value)
            self.produce(value=msg.value, key=msg.key)

    def cleanup(self, failed):
        if "cleanup" in self.error_in:
            self.error("test cleanup error")
        super().cleanup(failed)

    def stop(self):
        if "stop" in self.error_in:
            self.error("test stop error")
        super().stop()

    def error(self, msg):
        if self.pickeable_error:
            raise ValueError(msg)
        else:
            raise UnpickleableError(msg)
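
# Illustrative usage (assumption): create a source that signals the test once
# all of its values have been produced, then injects an error during `stop`.
def _example_failing_source() -> Tuple[DummySource, threading.Event]:
    finished = threading.Event()
    source = DummySource(values=[1, 2, 3], finished=finished, error_in=["stop"])
    # The test would register `source` with an application, run it, and then
    # wait on `finished` before asserting on the produced messages.
    return source, finished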

class DummyStatefulSource(DummySource, StatefulSource):
    """A stateful variant of `DummySource` that checkpoints the last produced
    value under `state_key` and can assert a recovered value on startup."""

    def __init__(
        self,
        name: Optional[str] = None,
        values: Optional[Iterable[Any]] = None,
        finished: Optional[threading.Event] = None,
        error_in: Optional[Union[str, List[str]]] = None,
        pickeable_error: bool = True,
        state_key: str = "",
        assert_state_value: Any = None,
    ) -> None:
        super().__init__(name, values, finished, error_in, pickeable_error)
        self._state_key = state_key
        self._assert_state_value = assert_state_value

    def run(self):
        if self._assert_state_value is not None:
            # Verify that the value recovered from the state store matches
            # what a previous run checkpointed
            assert self._assert_state_value == self.state.get(self._state_key)
        super().run()

    def _produce(self):
        for value in self.values:
            msg = self.serialize(key=self.key, value=value)
            self.produce(value=msg.value, key=msg.key)
            self.state.set(self._state_key, value)
        self.flush()
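
# Illustrative usage (assumption): on a restart of a source with the same name,
# assert that the last value checkpointed by the previous run is recovered.
def _example_stateful_restart() -> DummyStatefulSource:
    # Suppose a previous run produced [1, 2, 3] and flushed, so the state store
    # should hold 3 under "last" when this instance starts.
    return DummyStatefulSource(
        name="stateful-source",
        values=[4, 5],
        state_key="last",
        assert_state_value=3,
    )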

class UnpickleableError(Exception):
    def __init__(self, *args: object) -> None:
        # Holding a threading.Lock makes instances of this exception impossible
        # to pickle, since locks can't be pickled
        self._ = threading.Lock()
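
# Illustrative check (assumption, not part of the original file): pickling this
# exception fails, which tests can rely on when simulating errors that cannot
# cross process boundaries.
def _example_unpickleable() -> None:
    import pickle

    try:
        pickle.dumps(UnpickleableError("boom"))
    except TypeError:
        pass  # expected: threading.Lock cannot be pickled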