Skip to content

Commit

Permalink
Prefork: Use poll() to avoid limitations of select() (Issue celery#2373)
Browse files Browse the repository at this point in the history
  • Loading branch information
ask committed Nov 13, 2014
1 parent 9bea083 commit 7245458
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 55 deletions.
41 changes: 32 additions & 9 deletions celery/concurrency/asynpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from collections import deque, namedtuple
from io import BytesIO
from numbers import Integral
from pickle import HIGHEST_PROTOCOL
from time import sleep
from weakref import WeakValueDictionary, ref
Expand Down Expand Up @@ -109,8 +110,11 @@ def _get_job_writer(job):
return writer() # is a weakref


def _select(readers=None, writers=None, err=None, timeout=0):
"""Simple wrapper to :class:`~select.select`.
def _select(readers=None, writers=None, err=None, timeout=0,
poll=select.poll, POLLIN=select.POLLIN,
POLLOUT=select.POLLOUT, POLLERR=select.POLLERR):
"""Simple wrapper to :class:`~select.select`, using :`~select.poll`
as the implementation.
:param readers: Set of reader fds to test if readable.
:param writers: Set of writer fds to test if writable.
Expand All @@ -131,25 +135,44 @@ def _select(readers=None, writers=None, err=None, timeout=0):
readers = set() if readers is None else readers
writers = set() if writers is None else writers
err = set() if err is None else err
poller = poll()
register = poller.register

if readers:
[register(fd, POLLIN) for fd in readers]
if writers:
[register(fd, POLLOUT) for fd in writers]
if err:
[register(fd, POLLERR) for fd in err]

R, W = set(), set()
timeout = 0 if timeout and timeout < 0 else round(timeout * 1e3)
try:
r, w, e = select.select(readers, writers, err, timeout)
if e:
r = list(set(r) | set(e))
return r, w, 0
events = poller.poll(timeout)
for fd, event in events:
if not isinstance(fd, Integral):
fd = fd.fileno()
if event & POLLIN:
R.add(fd)
if event & POLLOUT:
W.add(fd)
if event & POLLERR:
R.add(fd)
return R, W, 0
except (select.error, socket.error) as exc:
if exc.errno == errno.EINTR:
return [], [], 1
return set(), set(), 1
elif exc.errno in SELECT_BAD_FD:
for fd in readers | writers | err:
try:
select.select([fd], [], [], 0)
except (select.error, socket.error) as exc:
if exc.errno not in SELECT_BAD_FD:
if getattr(exc, 'errno', None) not in SELECT_BAD_FD:
raise
readers.discard(fd)
writers.discard(fd)
err.discard(fd)
return [], [], 1
return set(), set(), 1
else:
raise

Expand Down
109 changes: 63 additions & 46 deletions celery/tests/concurrency/test_prefork.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from __future__ import absolute_import

import errno
import select
import socket
import time

from itertools import cycle

from celery.five import items, range
from celery.utils.functional import noop
from celery.tests.case import AppCase, Mock, SkipTest, call, patch
from celery.tests.case import AppCase, Mock, SkipTest, patch
try:
from celery.concurrency import prefork as mp
from celery.concurrency import asynpool
Expand Down Expand Up @@ -147,67 +148,83 @@ def gen():
list(g)
self.assertFalse(asynpool.gen_not_started(g))

def test_select(self):
@patch('select.select', create=True)
def test_select(self, __select):
ebadf = socket.error()
ebadf.errno = errno.EBADF
with patch('select.select') as select:
select.return_value = ([3], [], [])
with patch('select.poll', create=True) as poller:
poll = poller.return_value = Mock(name='poll.poll')
poll.poll.return_value = [(3, select.POLLIN)]
self.assertEqual(
asynpool._select({3}),
([3], [], 0),
asynpool._select({3}, poll=poller),
({3}, set(), 0),
)

select.return_value = ([], [], [3])
poll.poll.return_value = [(3, select.POLLERR)]
self.assertEqual(
asynpool._select({3}, None, {3}),
([3], [], 0),
asynpool._select({3}, None, {3}, poll=poller),
({3}, set(), 0),
)

eintr = socket.error()
eintr.errno = errno.EINTR
select.side_effect = eintr
poll.poll.side_effect = eintr

readers = {3}
self.assertEqual(asynpool._select(readers), ([], [], 1))
self.assertEqual(
asynpool._select(readers, poll=poller),
(set(), set(), 1),
)
self.assertIn(3, readers)

with patch('select.select') as select:
select.side_effect = ebadf
readers = {3}
self.assertEqual(asynpool._select(readers), ([], [], 1))
select.assert_has_calls([call([3], [], [], 0)])
self.assertNotIn(3, readers)

with patch('select.select') as select:
select.side_effect = MemoryError()
with self.assertRaises(MemoryError):
asynpool._select({1})

with patch('select.select') as select:

def se(*args):
select.side_effect = MemoryError()
raise ebadf
select.side_effect = se
with patch('select.poll') as poller:
poll = poller.return_value = Mock(name='poll.poll')
poll.poll.side_effect = ebadf
with patch('select.select') as selcheck:
selcheck.side_effect = ebadf
readers = {3}
self.assertEqual(
asynpool._select(readers, poll=poller),
(set(), set(), 1),
)
self.assertNotIn(3, readers)

with patch('select.poll') as poller:
poll = poller.return_value = Mock(name='poll.poll')
poll.poll.side_effect = MemoryError()
with self.assertRaises(MemoryError):
asynpool._select({3})

with patch('select.select') as select:

def se2(*args):
select.side_effect = socket.error()
select.side_effect.errno = 1321
raise ebadf
select.side_effect = se2
with self.assertRaises(socket.error):
asynpool._select({3})

with patch('select.select') as select:

select.side_effect = socket.error()
select.side_effect.errno = 34134
asynpool._select({1}, poll=poller)

with patch('select.poll') as poller:
poll = poller.return_value = Mock(name='poll.poll')
with patch('select.select') as selcheck:

def se(*args):
selcheck.side_effect = MemoryError()
raise ebadf
poll.poll.side_effect = se
with self.assertRaises(MemoryError):
asynpool._select({3}, poll=poller)

with patch('select.poll') as poller:
poll = poller.return_value = Mock(name='poll.poll')
with patch('select.select') as selcheck:

def se2(*args):
selcheck.side_effect = socket.error()
selcheck.side_effect.errno = 1321
raise ebadf
poll.poll.side_effect = se2
with self.assertRaises(socket.error):
asynpool._select({3}, poll=poller)

with patch('select.poll') as poller:
poll = poller.return_value = Mock(name='poll.poll')

poll.poll.side_effect = socket.error()
poll.poll.side_effect.errno = 34134
with self.assertRaises(socket.error):
asynpool._select({3})
asynpool._select({3}, poll=poller)

def test_promise(self):
fun = Mock()
Expand Down

0 comments on commit 7245458

Please sign in to comment.