-
Notifications
You must be signed in to change notification settings - Fork 114
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add cancel token for use in p2p service (#136)
- Loading branch information
Showing
5 changed files
with
360 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# cancel-token for QuarkChain services | ||
see https://github.com/ethereum/asyncio-cancel-token | ||
|
||
ported to work with python3.5 (pypy) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
class BaseCancelTokenException(Exception): | ||
""" | ||
Base exception class for the `asyncio-cancel-token` library. | ||
""" | ||
|
||
pass | ||
|
||
|
||
class EventLoopMismatch(BaseCancelTokenException): | ||
""" | ||
Raised when two different asyncio event loops are referenced, but must be equal | ||
""" | ||
|
||
pass | ||
|
||
|
||
class OperationCancelled(BaseCancelTokenException): | ||
""" | ||
Raised when an operation was cancelled. | ||
""" | ||
|
||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
import asyncio | ||
import functools | ||
|
||
import pytest | ||
|
||
from quarkchain.p2p.cancel_token.token import ( | ||
CancelToken, | ||
EventLoopMismatch, | ||
OperationCancelled, | ||
) | ||
|
||
|
||
def test_token_single(): | ||
token = CancelToken("token") | ||
assert not token.triggered | ||
token.trigger() | ||
assert token.triggered | ||
assert token.triggered_token == token | ||
|
||
|
||
def test_token_chain_event_loop_mismatch(): | ||
token = CancelToken("token") | ||
token2 = CancelToken("token2", loop=asyncio.new_event_loop()) | ||
with pytest.raises(EventLoopMismatch): | ||
token.chain(token2) | ||
|
||
|
||
def test_token_chain_trigger_chain(): | ||
token = CancelToken("token") | ||
token2 = CancelToken("token2") | ||
token3 = CancelToken("token3") | ||
intermediate_chain = token.chain(token2) | ||
chain = intermediate_chain.chain(token3) | ||
assert not chain.triggered | ||
chain.trigger() | ||
assert chain.triggered | ||
assert not intermediate_chain.triggered | ||
assert chain.triggered_token == chain | ||
assert not token.triggered | ||
assert not token2.triggered | ||
assert not token3.triggered | ||
|
||
|
||
def test_token_chain_trigger_first(): | ||
token = CancelToken("token") | ||
token2 = CancelToken("token2") | ||
token3 = CancelToken("token3") | ||
chain = token.chain(token2).chain(token3) | ||
assert not chain.triggered | ||
token.trigger() | ||
assert chain.triggered | ||
assert chain.triggered_token == token | ||
|
||
|
||
def test_token_chain_trigger_middle(): | ||
token = CancelToken("token") | ||
token2 = CancelToken("token2") | ||
token3 = CancelToken("token3") | ||
intermediate_chain = token.chain(token2) | ||
chain = intermediate_chain.chain(token3) | ||
assert not chain.triggered | ||
token2.trigger() | ||
assert chain.triggered | ||
assert intermediate_chain.triggered | ||
assert chain.triggered_token == token2 | ||
assert not token3.triggered | ||
assert not token.triggered | ||
|
||
|
||
def test_token_chain_trigger_last(): | ||
token = CancelToken("token") | ||
token2 = CancelToken("token2") | ||
token3 = CancelToken("token3") | ||
intermediate_chain = token.chain(token2) | ||
chain = intermediate_chain.chain(token3) | ||
assert not chain.triggered | ||
token3.trigger() | ||
assert chain.triggered | ||
assert chain.triggered_token == token3 | ||
assert not intermediate_chain.triggered | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_token_wait(event_loop): | ||
token = CancelToken("token") | ||
event_loop.call_soon(token.trigger) | ||
done, pending = await asyncio.wait([token.wait()], timeout=0.1) | ||
assert len(done) == 1 | ||
assert len(pending) == 0 | ||
assert token.triggered | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_wait_cancel_pending_tasks_on_completion(event_loop): | ||
token = CancelToken("token") | ||
token2 = CancelToken("token2") | ||
chain = token.chain(token2) | ||
event_loop.call_soon(token2.trigger) | ||
await chain.wait() | ||
await assert_only_current_task_not_done() | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_wait_cancel_pending_tasks_on_cancellation(event_loop): | ||
"""Test that cancelling a pending CancelToken.wait() coroutine doesn't leave .wait() | ||
coroutines for any chained tokens behind. | ||
""" | ||
token = ( | ||
CancelToken("token").chain(CancelToken("token2")).chain(CancelToken("token3")) | ||
) | ||
token_wait_coroutine = token.wait() | ||
done, pending = await asyncio.wait([token_wait_coroutine], timeout=0.1) | ||
assert len(done) == 0 | ||
assert len(pending) == 1 | ||
pending_task = pending.pop() | ||
assert pending_task._coro == token_wait_coroutine | ||
pending_task.cancel() | ||
await assert_only_current_task_not_done() | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_cancellable_wait(event_loop): | ||
fut = asyncio.Future() | ||
event_loop.call_soon(functools.partial(fut.set_result, "result")) | ||
result = await CancelToken("token").cancellable_wait(fut, timeout=1) | ||
assert result == "result" | ||
await assert_only_current_task_not_done() | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_cancellable_wait_future_exception(event_loop): | ||
fut = asyncio.Future() | ||
event_loop.call_soon(functools.partial(fut.set_exception, Exception())) | ||
with pytest.raises(Exception): | ||
await CancelToken("token").cancellable_wait(fut, timeout=1) | ||
await assert_only_current_task_not_done() | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_cancellable_wait_cancels_subtasks_when_cancelled(event_loop): | ||
token = CancelToken("") | ||
future = asyncio.ensure_future(token.cancellable_wait(asyncio.sleep(2))) | ||
with pytest.raises(asyncio.TimeoutError): | ||
# asyncio.wait_for() will timeout and then cancel our cancellable_wait() future, but | ||
# Task.cancel() doesn't immediately cancels the task | ||
# (https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.cancel), so we need | ||
# the sleep below before we check that the task is actually cancelled. | ||
await asyncio.wait_for(future, timeout=0.01) | ||
await asyncio.sleep(0) | ||
assert future.cancelled() | ||
await assert_only_current_task_not_done() | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_cancellable_wait_timeout(): | ||
with pytest.raises(TimeoutError): | ||
await CancelToken("token").cancellable_wait(asyncio.sleep(0.02), timeout=0.01) | ||
await assert_only_current_task_not_done() | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_cancellable_wait_operation_cancelled(event_loop): | ||
token = CancelToken("token") | ||
token.trigger() | ||
with pytest.raises(OperationCancelled): | ||
await token.cancellable_wait(asyncio.sleep(0.02)) | ||
await assert_only_current_task_not_done() | ||
|
||
|
||
async def assert_only_current_task_not_done(): | ||
# This sleep() is necessary because Task.cancel() doesn't immediately cancels the task: | ||
# https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.cancel | ||
await asyncio.sleep(0.01) | ||
for task in asyncio.Task.all_tasks(): | ||
if task == asyncio.Task.current_task(): | ||
# This is the task for this very test, so it will be running | ||
assert not task.done() | ||
else: | ||
assert task.done() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
import asyncio | ||
from typing import Any, Awaitable, List, Sequence, TypeVar, cast # noqa: F401 | ||
|
||
from .exceptions import EventLoopMismatch, OperationCancelled | ||
|
||
_R = TypeVar("_R") | ||
|
||
|
||
class CancelToken: | ||
def __init__(self, name: str, loop: asyncio.AbstractEventLoop = None) -> None: | ||
self.name = name | ||
self._chain = [] # : List['CancelToken'] | ||
self._triggered = asyncio.Event(loop=loop) | ||
self._loop = loop | ||
|
||
@property | ||
def loop(self) -> asyncio.AbstractEventLoop: | ||
""" | ||
Return the `loop` that this token is bound to. | ||
""" | ||
return self._loop | ||
|
||
def chain(self, token: "CancelToken") -> "CancelToken": | ||
""" | ||
Return a new CancelToken chaining this and the given token. | ||
The new CancelToken's triggered will return True if trigger() has been | ||
called on either of the chained tokens, but calling trigger() on the new token | ||
has no effect on either of the chained tokens. | ||
""" | ||
if self.loop != token._loop: | ||
raise EventLoopMismatch( | ||
"Chained CancelToken objects must be on the same event loop" | ||
) | ||
chain_name = ":".join([self.name, token.name]) | ||
chain = CancelToken(chain_name, loop=self.loop) | ||
chain._chain.extend([self, token]) | ||
return chain | ||
|
||
def trigger(self) -> None: | ||
""" | ||
Trigger this cancel token and any child tokens that have been chained with it. | ||
""" | ||
self._triggered.set() | ||
|
||
@property | ||
def triggered_token(self) -> "CancelToken": | ||
""" | ||
Return the token which was triggered. | ||
The returned token may be this token or one that it was chained with. | ||
""" | ||
if self._triggered.is_set(): | ||
return self | ||
for token in self._chain: | ||
if token.triggered: | ||
# Use token.triggered_token here to make the lookup recursive as self._chain may | ||
# contain other chains. | ||
return token.triggered_token | ||
return None | ||
|
||
@property | ||
def triggered(self) -> bool: | ||
""" | ||
Return `True` or `False` whether this token has been triggered. | ||
""" | ||
if self._triggered.is_set(): | ||
return True | ||
return any(token.triggered for token in self._chain) | ||
|
||
def raise_if_triggered(self) -> None: | ||
""" | ||
Raise `OperationCancelled` if this token has been triggered. | ||
""" | ||
if self.triggered: | ||
raise OperationCancelled( | ||
"Cancellation requested by {} token".format(self.triggered_token) | ||
) | ||
|
||
async def wait(self) -> None: | ||
""" | ||
Coroutine which returns when this token has been triggered | ||
""" | ||
if self.triggered_token is not None: | ||
return | ||
|
||
futures = [asyncio.ensure_future(self._triggered.wait(), loop=self.loop)] | ||
for token in self._chain: | ||
futures.append(asyncio.ensure_future(token.wait(), loop=self.loop)) | ||
|
||
def cancel_not_done(fut: "asyncio.Future[None]") -> None: | ||
for future in futures: | ||
if not future.done(): | ||
future.cancel() | ||
|
||
async def _wait_for_first(futures: Sequence[Awaitable[Any]]) -> None: | ||
for future in asyncio.as_completed(futures): | ||
# We don't need to catch CancelledError here (and cancel not done futures) | ||
# because our callback (above) takes care of that. | ||
await cast(Awaitable[Any], future) | ||
return | ||
|
||
fut = asyncio.ensure_future(_wait_for_first(futures), loop=self.loop) | ||
fut.add_done_callback(cancel_not_done) | ||
await fut | ||
|
||
async def cancellable_wait( | ||
self, *awaitables: Awaitable[_R], timeout: float = None | ||
) -> _R: | ||
""" | ||
Wait for the first awaitable to complete, unless we timeout or the | ||
token is triggered. | ||
Returns the result of the first awaitable to complete. | ||
Raises TimeoutError if we timeout or | ||
`~cancel_token.exceptions.OperationCancelled` if the cancel token is | ||
triggered. | ||
All pending futures are cancelled before returning. | ||
""" | ||
futures = [ | ||
asyncio.ensure_future(a, loop=self.loop) | ||
for a in awaitables + (self.wait(),) | ||
] | ||
try: | ||
done, pending = await asyncio.wait( | ||
futures, | ||
timeout=timeout, | ||
return_when=asyncio.FIRST_COMPLETED, | ||
loop=self.loop, | ||
) | ||
except asyncio.futures.CancelledError: | ||
# Since we use return_when=asyncio.FIRST_COMPLETED above, we can be sure none of our | ||
# futures will be done here, so we don't need to check if any is done before cancelling. | ||
for future in futures: | ||
future.cancel() | ||
raise | ||
for task in pending: | ||
task.cancel() | ||
if not done: | ||
raise TimeoutError() | ||
if self.triggered_token is not None: | ||
# We've been asked to cancel so we don't care about our future, but we must | ||
# consume its exception or else asyncio will emit warnings. | ||
for task in done: | ||
task.exception() | ||
raise OperationCancelled( | ||
"Cancellation requested by {} token".format(self.triggered_token) | ||
) | ||
return done.pop().result() | ||
|
||
def __str__(self) -> str: | ||
return self.name | ||
|
||
def __repr__(self) -> str: | ||
return "<CancelToken: {0}>".format(self.name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters