Skip to content

Commit

Permalink
add cancel token for use in p2p service (#136)
Browse files Browse the repository at this point in the history
  • Loading branch information
qcdll authored Oct 4, 2018
1 parent 9e6cabb commit 8fd0059
Show file tree
Hide file tree
Showing 5 changed files with 360 additions and 0 deletions.
4 changes: 4 additions & 0 deletions quarkchain/p2p/cancel_token/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# cancel-token for QuarkChain services
see https://github.com/ethereum/asyncio-cancel-token

ported to work with python3.5 (pypy)
22 changes: 22 additions & 0 deletions quarkchain/p2p/cancel_token/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
class BaseCancelTokenException(Exception):
"""
Base exception class for the `asyncio-cancel-token` library.
"""

pass


class EventLoopMismatch(BaseCancelTokenException):
"""
Raised when two different asyncio event loops are referenced, but must be equal
"""

pass


class OperationCancelled(BaseCancelTokenException):
"""
Raised when an operation was cancelled.
"""

pass
179 changes: 179 additions & 0 deletions quarkchain/p2p/cancel_token/tests/test_cancel_token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import asyncio
import functools

import pytest

from quarkchain.p2p.cancel_token.token import (
CancelToken,
EventLoopMismatch,
OperationCancelled,
)


def test_token_single():
token = CancelToken("token")
assert not token.triggered
token.trigger()
assert token.triggered
assert token.triggered_token == token


def test_token_chain_event_loop_mismatch():
token = CancelToken("token")
token2 = CancelToken("token2", loop=asyncio.new_event_loop())
with pytest.raises(EventLoopMismatch):
token.chain(token2)


def test_token_chain_trigger_chain():
token = CancelToken("token")
token2 = CancelToken("token2")
token3 = CancelToken("token3")
intermediate_chain = token.chain(token2)
chain = intermediate_chain.chain(token3)
assert not chain.triggered
chain.trigger()
assert chain.triggered
assert not intermediate_chain.triggered
assert chain.triggered_token == chain
assert not token.triggered
assert not token2.triggered
assert not token3.triggered


def test_token_chain_trigger_first():
token = CancelToken("token")
token2 = CancelToken("token2")
token3 = CancelToken("token3")
chain = token.chain(token2).chain(token3)
assert not chain.triggered
token.trigger()
assert chain.triggered
assert chain.triggered_token == token


def test_token_chain_trigger_middle():
token = CancelToken("token")
token2 = CancelToken("token2")
token3 = CancelToken("token3")
intermediate_chain = token.chain(token2)
chain = intermediate_chain.chain(token3)
assert not chain.triggered
token2.trigger()
assert chain.triggered
assert intermediate_chain.triggered
assert chain.triggered_token == token2
assert not token3.triggered
assert not token.triggered


def test_token_chain_trigger_last():
token = CancelToken("token")
token2 = CancelToken("token2")
token3 = CancelToken("token3")
intermediate_chain = token.chain(token2)
chain = intermediate_chain.chain(token3)
assert not chain.triggered
token3.trigger()
assert chain.triggered
assert chain.triggered_token == token3
assert not intermediate_chain.triggered


@pytest.mark.asyncio
async def test_token_wait(event_loop):
token = CancelToken("token")
event_loop.call_soon(token.trigger)
done, pending = await asyncio.wait([token.wait()], timeout=0.1)
assert len(done) == 1
assert len(pending) == 0
assert token.triggered


@pytest.mark.asyncio
async def test_wait_cancel_pending_tasks_on_completion(event_loop):
token = CancelToken("token")
token2 = CancelToken("token2")
chain = token.chain(token2)
event_loop.call_soon(token2.trigger)
await chain.wait()
await assert_only_current_task_not_done()


@pytest.mark.asyncio
async def test_wait_cancel_pending_tasks_on_cancellation(event_loop):
"""Test that cancelling a pending CancelToken.wait() coroutine doesn't leave .wait()
coroutines for any chained tokens behind.
"""
token = (
CancelToken("token").chain(CancelToken("token2")).chain(CancelToken("token3"))
)
token_wait_coroutine = token.wait()
done, pending = await asyncio.wait([token_wait_coroutine], timeout=0.1)
assert len(done) == 0
assert len(pending) == 1
pending_task = pending.pop()
assert pending_task._coro == token_wait_coroutine
pending_task.cancel()
await assert_only_current_task_not_done()


@pytest.mark.asyncio
async def test_cancellable_wait(event_loop):
fut = asyncio.Future()
event_loop.call_soon(functools.partial(fut.set_result, "result"))
result = await CancelToken("token").cancellable_wait(fut, timeout=1)
assert result == "result"
await assert_only_current_task_not_done()


@pytest.mark.asyncio
async def test_cancellable_wait_future_exception(event_loop):
fut = asyncio.Future()
event_loop.call_soon(functools.partial(fut.set_exception, Exception()))
with pytest.raises(Exception):
await CancelToken("token").cancellable_wait(fut, timeout=1)
await assert_only_current_task_not_done()


@pytest.mark.asyncio
async def test_cancellable_wait_cancels_subtasks_when_cancelled(event_loop):
token = CancelToken("")
future = asyncio.ensure_future(token.cancellable_wait(asyncio.sleep(2)))
with pytest.raises(asyncio.TimeoutError):
# asyncio.wait_for() will timeout and then cancel our cancellable_wait() future, but
# Task.cancel() doesn't immediately cancels the task
# (https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.cancel), so we need
# the sleep below before we check that the task is actually cancelled.
await asyncio.wait_for(future, timeout=0.01)
await asyncio.sleep(0)
assert future.cancelled()
await assert_only_current_task_not_done()


@pytest.mark.asyncio
async def test_cancellable_wait_timeout():
with pytest.raises(TimeoutError):
await CancelToken("token").cancellable_wait(asyncio.sleep(0.02), timeout=0.01)
await assert_only_current_task_not_done()


@pytest.mark.asyncio
async def test_cancellable_wait_operation_cancelled(event_loop):
token = CancelToken("token")
token.trigger()
with pytest.raises(OperationCancelled):
await token.cancellable_wait(asyncio.sleep(0.02))
await assert_only_current_task_not_done()


async def assert_only_current_task_not_done():
# This sleep() is necessary because Task.cancel() doesn't immediately cancels the task:
# https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.cancel
await asyncio.sleep(0.01)
for task in asyncio.Task.all_tasks():
if task == asyncio.Task.current_task():
# This is the task for this very test, so it will be running
assert not task.done()
else:
assert task.done()
152 changes: 152 additions & 0 deletions quarkchain/p2p/cancel_token/token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import asyncio
from typing import Any, Awaitable, List, Sequence, TypeVar, cast # noqa: F401

from .exceptions import EventLoopMismatch, OperationCancelled

_R = TypeVar("_R")


class CancelToken:
def __init__(self, name: str, loop: asyncio.AbstractEventLoop = None) -> None:
self.name = name
self._chain = [] # : List['CancelToken']
self._triggered = asyncio.Event(loop=loop)
self._loop = loop

@property
def loop(self) -> asyncio.AbstractEventLoop:
"""
Return the `loop` that this token is bound to.
"""
return self._loop

def chain(self, token: "CancelToken") -> "CancelToken":
"""
Return a new CancelToken chaining this and the given token.
The new CancelToken's triggered will return True if trigger() has been
called on either of the chained tokens, but calling trigger() on the new token
has no effect on either of the chained tokens.
"""
if self.loop != token._loop:
raise EventLoopMismatch(
"Chained CancelToken objects must be on the same event loop"
)
chain_name = ":".join([self.name, token.name])
chain = CancelToken(chain_name, loop=self.loop)
chain._chain.extend([self, token])
return chain

def trigger(self) -> None:
"""
Trigger this cancel token and any child tokens that have been chained with it.
"""
self._triggered.set()

@property
def triggered_token(self) -> "CancelToken":
"""
Return the token which was triggered.
The returned token may be this token or one that it was chained with.
"""
if self._triggered.is_set():
return self
for token in self._chain:
if token.triggered:
# Use token.triggered_token here to make the lookup recursive as self._chain may
# contain other chains.
return token.triggered_token
return None

@property
def triggered(self) -> bool:
"""
Return `True` or `False` whether this token has been triggered.
"""
if self._triggered.is_set():
return True
return any(token.triggered for token in self._chain)

def raise_if_triggered(self) -> None:
"""
Raise `OperationCancelled` if this token has been triggered.
"""
if self.triggered:
raise OperationCancelled(
"Cancellation requested by {} token".format(self.triggered_token)
)

async def wait(self) -> None:
"""
Coroutine which returns when this token has been triggered
"""
if self.triggered_token is not None:
return

futures = [asyncio.ensure_future(self._triggered.wait(), loop=self.loop)]
for token in self._chain:
futures.append(asyncio.ensure_future(token.wait(), loop=self.loop))

def cancel_not_done(fut: "asyncio.Future[None]") -> None:
for future in futures:
if not future.done():
future.cancel()

async def _wait_for_first(futures: Sequence[Awaitable[Any]]) -> None:
for future in asyncio.as_completed(futures):
# We don't need to catch CancelledError here (and cancel not done futures)
# because our callback (above) takes care of that.
await cast(Awaitable[Any], future)
return

fut = asyncio.ensure_future(_wait_for_first(futures), loop=self.loop)
fut.add_done_callback(cancel_not_done)
await fut

async def cancellable_wait(
self, *awaitables: Awaitable[_R], timeout: float = None
) -> _R:
"""
Wait for the first awaitable to complete, unless we timeout or the
token is triggered.
Returns the result of the first awaitable to complete.
Raises TimeoutError if we timeout or
`~cancel_token.exceptions.OperationCancelled` if the cancel token is
triggered.
All pending futures are cancelled before returning.
"""
futures = [
asyncio.ensure_future(a, loop=self.loop)
for a in awaitables + (self.wait(),)
]
try:
done, pending = await asyncio.wait(
futures,
timeout=timeout,
return_when=asyncio.FIRST_COMPLETED,
loop=self.loop,
)
except asyncio.futures.CancelledError:
# Since we use return_when=asyncio.FIRST_COMPLETED above, we can be sure none of our
# futures will be done here, so we don't need to check if any is done before cancelling.
for future in futures:
future.cancel()
raise
for task in pending:
task.cancel()
if not done:
raise TimeoutError()
if self.triggered_token is not None:
# We've been asked to cancel so we don't care about our future, but we must
# consume its exception or else asyncio will emit warnings.
for task in done:
task.exception()
raise OperationCancelled(
"Cancellation requested by {} token".format(self.triggered_token)
)
return done.pop().result()

def __str__(self) -> str:
return self.name

def __repr__(self) -> str:
return "<CancelToken: {0}>".format(self.name)
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ eth-bloom==1.0.0
pyethash>=0.1.27,<1.0.0
py_ecc==1.4.3
eth-hash[pycryptodome]==0.1.4

# p2p
pytest>=3.6,<3.7
pytest-asyncio==0.9.0

# pyethapp/accounts.py dependency
pbkdf2
Expand Down

0 comments on commit 8fd0059

Please sign in to comment.