Skip to content

Commit

Permalink
Add tornado.locks.Semaphore.
Browse files Browse the repository at this point in the history
  • Loading branch information
ajdavis committed Feb 20, 2015
1 parent 502b1ad commit 10fd949
Show file tree
Hide file tree
Showing 2 changed files with 238 additions and 1 deletion.
122 changes: 121 additions & 1 deletion tornado/locks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@

from __future__ import absolute_import, division, print_function, with_statement

__all__ = ['Condition', 'Event']
__all__ = ['Condition', 'Event', 'Semaphore']

import collections
import contextlib

from tornado import gen, ioloop
from tornado.concurrent import Future
Expand Down Expand Up @@ -123,3 +124,122 @@ def wait(self, timeout=None):
return self._future
else:
return gen.with_timeout(timeout, self._future)


class _ContextManagerFuture(Future):
"""A Future that can be used with the "with" statement.
When a coroutine yields this Future, the return value is a context manager
that can be used like:
with (yield future):
pass
At the end of the block, the Future's exit callback is run. Used for
Lock.acquire and Semaphore.acquire.
"""
def __init__(self, wrapped, exit_callback):
super(_ContextManagerFuture, self).__init__()
gen.chain_future(wrapped, self)
self.exit_callback = exit_callback

def result(self, timeout=None):
if self.exception():
raise self.exception()

# Otherwise return a context manager that cleans up after the block.
@contextlib.contextmanager
def f():
try:
yield
finally:
self.exit_callback()
return f()


class Semaphore(object):
"""A lock that can be acquired a fixed number of times before blocking.
A Semaphore manages a counter representing the number of `.release` calls
minus the number of `.acquire` calls, plus an initial value. The `.acquire`
method blocks if necessary until it can return without making the counter
negative.
`.acquire` supports the context manager protocol:
>>> from tornado import gen, locks
>>> semaphore = locks.Semaphore()
>>> @gen.coroutine
... def f():
... with (yield semaphore.acquire()):
... assert semaphore.locked()
...
... assert not semaphore.locked()
.. note:: Unlike the standard `threading.Semaphore`, a Tornado `.Semaphore`
can tell you the current value of its `.counter`, because code in a
single-threaded Tornado application can check this value and act upon
it without fear of interruption from another thread.
"""
def __init__(self, value=1):
if value < 0:
raise ValueError('semaphore initial value must be >= 0')

self.io_loop = ioloop.IOLoop.current()
self._value = value
self._waiters = collections.deque()

def __repr__(self):
res = super(Semaphore, self).__repr__()
extra = 'locked' if self.locked() else 'unlocked,value:{0}'.format(
self._value)
if self._waiters:
extra = '{0},waiters:{1}'.format(extra, len(self._waiters))
return '<{0} [{1}]>'.format(res[1:-1], extra)

@property
def counter(self):
"""An integer, the current semaphore value."""
return self._value

def locked(self):
"""True if the semaphore cannot be acquired immediately."""
return self._value == 0

def release(self):
"""Increment `.counter` and wake one waiter."""
self._value += 1
for waiter in self._waiters:
if not waiter.done():
self._value -= 1
waiter.set_result(None)
break

def acquire(self, timeout=None):
"""Decrement `.counter`. Returns a Future.
Block if the counter is zero and wait for a `.release`. The Future
raises `.TimeoutError` after the deadline.
"""
if self._value > 0:
self._value -= 1
future = gen._null_future
else:
waiter = Future()
self._waiters.append(waiter)
if timeout:
future = gen.with_timeout(timeout, waiter, self.io_loop,
quiet_exceptions=gen.TimeoutError)

# Set waiter's exception after the deadline.
gen.chain_future(future, waiter)
else:
future = waiter
return _ContextManagerFuture(future, self.release)

def __enter__(self):
raise RuntimeError(
"Use Semaphore like 'with (yield semaphore.acquire())', not like"
" 'with semaphore'")

__exit__ = __enter__
117 changes: 117 additions & 0 deletions tornado/test/locks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,5 +216,122 @@ def test_event_wait_clear(self):
self.assertTrue(f1.done())


class SemaphoreTest(AsyncTestCase):
def test_negative_value(self):
self.assertRaises(ValueError, locks.Semaphore, value=-1)

def test_str(self):
sem = locks.Semaphore()
self.assertIn('Semaphore', str(sem))
self.assertIn('unlocked,value:1', str(sem))
sem.acquire()
self.assertIn('locked', str(sem))
self.assertNotIn('waiters', str(sem))
sem.acquire()
self.assertIn('waiters', str(sem))

def test_acquire(self):
sem = locks.Semaphore()
self.assertFalse(sem.locked())
f0 = sem.acquire()
self.assertTrue(f0.done())
self.assertTrue(sem.locked())

# Wait for release().
f1 = sem.acquire()
f2 = sem.acquire()
sem.release()
self.assertTrue(f1.done())
self.assertFalse(f2.done())
sem.release()
self.assertTrue(f2.done())

sem.release()
# Now acquire() is instant.
self.assertTrue(sem.acquire().done())

@gen_test
def test_acquire_timeout(self):
sem = locks.Semaphore(2)
yield sem.acquire()
yield sem.acquire()
with self.assertRaises(gen.TimeoutError):
yield sem.acquire(timedelta(seconds=0.01))

f = sem.acquire()
sem.release()
self.assertTrue(f.done())

def test_release_unacquired(self):
# Unbounded releases are allowed, and increment the semaphore's value.
sem = locks.Semaphore()
sem.release()
sem.release()
self.assertEqual(3, sem.counter)


class SemaphoreContextManagerTest(AsyncTestCase):
@gen_test
def test_context_manager(self):
sem = locks.Semaphore()
with (yield sem.acquire()) as yielded:
self.assertTrue(sem.locked())
self.assertTrue(yielded is None)

self.assertFalse(sem.locked())

@gen_test
def test_context_manager_exception(self):
sem = locks.Semaphore()
with self.assertRaises(ZeroDivisionError):
with (yield sem.acquire()):
1 / 0

# Context manager released semaphore.
self.assertFalse(sem.locked())

@gen_test
def test_context_manager_timeout(self):
sem = locks.Semaphore(value=0)
with self.assertRaises(gen.TimeoutError):
with (yield sem.acquire(timedelta(seconds=0.01))):
pass

@gen_test
def test_context_manager_contended(self):
sem = locks.Semaphore()
history = []

@gen.coroutine
def f(index):
with (yield sem.acquire()):
history.append('acquired %d' % index)
yield gen.sleep(0.01)
history.append('release %d' % index)

yield [f(i) for i in range(2)]

expected_history = []
for i in range(2):
expected_history.extend(['acquired %d' % i, 'release %d' % i])

self.assertEqual(expected_history, history)

@gen_test
def test_yield_sem(self):
# Ensure we catch a "with (yield sem)", which should be
# "with (yield sem.acquire())".
with self.assertRaises(gen.BadYieldError):
with (yield locks.Semaphore()):
pass

def test_context_manager_misuse(self):
# Ensure we catch a "with sem", which should be
# "with (yield sem.acquire())".
with self.assertRaises(RuntimeError):
with locks.Semaphore():
pass


if __name__ == '__main__':
unittest.main()

0 comments on commit 10fd949

Please sign in to comment.