Skip to content

Commit

Permalink
Add support for Unix sockets Fixes spotify#1118
Browse files Browse the repository at this point in the history
This also adds support for TLS!

Simply use https://example.com/ for TLS or
http+unix://%2Fvar%2Frun%2Fluigid%2Fluigid.sock/ for Unix sockets
  • Loading branch information
Thomas Grainger committed Aug 14, 2015
1 parent 116577e commit ecc591b
Show file tree
Hide file tree
Showing 12 changed files with 171 additions and 44 deletions.
3 changes: 3 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ env:
- TOXENV=py27-nonhdfs
- TOXENV=py33-nonhdfs
- TOXENV=py34-nonhdfs
- TOXENV=py27-unixsockets
- TOXENV=py33-unixsockets
- TOXENV=py34-unixsockets
- TOXENV=py27-cdh
- TOXENV=py33-cdh
- TOXENV=py34-cdh
Expand Down
3 changes: 2 additions & 1 deletion luigi/cmdline.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def luigid(argv=sys.argv[1:]):
parser.add_argument(u'--logdir', help=u'log directory')
parser.add_argument(u'--state-path', help=u'Pickled state file')
parser.add_argument(u'--address', help=u'Listening interface')
parser.add_argument(u'--unix-socket', help=u'Unix socket path')
parser.add_argument(u'--port', default=8082, help=u'Listening port')

opts = parser.parse_args(argv)
Expand All @@ -33,7 +34,7 @@ def luigid(argv=sys.argv[1:]):
logging.getLogger().setLevel(logging.INFO)
luigi.process.daemonize(luigi.server.run, api_port=opts.port,
address=opts.address, pidfile=opts.pidfile,
logdir=opts.logdir)
logdir=opts.logdir, unix_socket=opts.unix_socket)
else:
if opts.logdir:
logging.basicConfig(level=logging.INFO, format=luigi.process.get_log_format(),
Expand Down
20 changes: 15 additions & 5 deletions luigi/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ class core(task.Config):
default=8082,
description='Port of remote scheduler api process',
config_path=dict(section='core', name='default-scheduler-port'))
scheduler_url = parameter.Parameter(
default=None,
description='Full path to remote scheduler',
config_path=dict(section='core', name='default-scheduler-url'),
)
lock_size = parameter.IntParameter(
default=1,
description="Maximum number of workers running the same command")
Expand Down Expand Up @@ -110,8 +115,8 @@ class WorkerSchedulerFactory(object):
def create_local_scheduler(self):
return scheduler.CentralPlannerScheduler(prune_on_get_work=True)

def create_remote_scheduler(self, host, port):
return rpc.RemoteScheduler(host=host, port=port)
def create_remote_scheduler(self, url):
return rpc.RemoteScheduler(url)

def create_worker(self, scheduler, worker_processes, assistant=False):
return worker.Worker(
Expand Down Expand Up @@ -157,9 +162,14 @@ def run(tasks, worker_scheduler_factory=None, override_defaults=None):
if env_params.local_scheduler:
sch = worker_scheduler_factory.create_local_scheduler()
else:
sch = worker_scheduler_factory.create_remote_scheduler(
host=env_params.scheduler_host,
port=env_params.scheduler_port)
if env_params.scheduler_url is not None:
url = env_params.scheduler_url
else:
url = 'http://{host}:{port:d}/'.format(
host=env_params.scheduler_host,
port=env_params.scheduler_port,
)
sch = worker_scheduler_factory.create_remote_scheduler(url=url)

w = worker_scheduler_factory.create_worker(
scheduler=sch, worker_processes=env_params.workers, assistant=env_params.assistant)
Expand Down
4 changes: 2 additions & 2 deletions luigi/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _server_already_running(pidfile):
return False


def daemonize(cmd, pidfile=None, logdir=None, api_port=8082, address=None):
def daemonize(cmd, pidfile=None, logdir=None, api_port=8082, address=None, unix_socket=None):
import daemon

logdir = logdir or "/var/log/luigi"
Expand Down Expand Up @@ -112,4 +112,4 @@ def daemonize(cmd, pidfile=None, logdir=None, api_port=8082, address=None):
return
write_pid(pidfile)

cmd(api_port=api_port, address=address)
cmd(api_port=api_port, address=address, unix_socket=unix_socket)
84 changes: 64 additions & 20 deletions luigi/rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,28 @@
import socket
import time

from luigi.six.moves.urllib.parse import urlencode
from luigi.six.moves.urllib.parse import urlencode, ParseResult
from luigi.six.moves.urllib.request import urlopen
from luigi.six.moves.urllib.error import URLError

from luigi import configuration
from luigi.scheduler import PENDING, Scheduler


# Optional HTTP client backends. Both flags start out True and are cleared
# below as the corresponding imports fail.
HAS_UNIX_SOCKETS = True
HAS_REQUESTS = True


try:
    # The "requests-unixsocket" distribution installs a module named
    # ``requests_unixsocket`` (singular). Importing the plural spelling always
    # raised ImportError, which silently disabled Unix socket support.
    import requests_unixsocket as requests
except ImportError:
    HAS_UNIX_SOCKETS = False
    try:
        # Fall back to plain requests (HTTP/HTTPS only).
        import requests
    except ImportError:
        HAS_REQUESTS = False


logger = logging.getLogger('luigi-interface') # TODO: 'interface'?


Expand All @@ -42,32 +57,62 @@ def __init__(self, message, sub_exception=None):
self.sub_exception = sub_exception


class FetcherException(Exception):
    """
    Raised by a fetcher when a request fails.

    The backend-specific error that caused the failure is kept on
    ``original_exc``.
    """

    def __init__(self, original_exc):
        # Forward the cause to Exception so that str(), repr() and ``args``
        # are populated instead of being empty.
        super(FetcherException, self).__init__(original_exc)
        self.original_exc = original_exc


class URLLibFetcher(object):
    """
    Fetcher built on ``urlopen``.
    """

    def fetch(self, full_url, body, timeout):
        """
        Send ``body`` url-encoded to ``full_url`` and return the decoded reply.

        :param full_url: complete URL to request.
        :param body: mapping (or sequence of pairs) accepted by ``urlencode``.
        :param timeout: timeout handed to ``urlopen``.
        :return: the response payload decoded as UTF-8 text.
        :raises FetcherException: wrapping ``URLError`` or ``socket.timeout``.
        """
        from contextlib import closing

        try:
            encoded = urlencode(body).encode('utf-8')
            # closing() guarantees the response is released even if read()
            # or decode() fails; it also works where the response object is
            # not itself a context manager.
            with closing(urlopen(full_url, encoded, timeout)) as response:
                return response.read().decode('utf-8')
        except (URLError, socket.timeout) as e:
            raise FetcherException(e)


class RequestsFetcher(object):
    """
    Fetcher that issues requests through a caller-supplied session object.
    """

    def __init__(self, session):
        """
        :param session: object whose ``get`` method performs the request.
        """
        self.session = session

    def fetch(self, full_url, body, timeout):
        """
        Request ``full_url`` with ``body`` as request data; return the text.

        :raises FetcherException: wrapping any ``RequestException``, including
                                  the one raised by ``raise_for_status``.
        """
        # Imported here, from the real ``requests`` package, rather than via
        # the module-level ``requests`` name.
        from requests import exceptions as requests_exceptions

        # NOTE(review): this sends the payload with GET, while the urllib
        # based fetcher passes data to urlopen -- confirm the server accepts both.
        try:
            response = self.session.get(full_url, data=body, timeout=timeout)
            response.raise_for_status()
            return response.text
        except requests_exceptions.RequestException as exc:
            raise FetcherException(exc)


class RemoteScheduler(Scheduler):
"""
Scheduler proxy object. Talks to a RemoteSchedulerResponder.
"""

def __init__(self, host='localhost', port=8082, connect_timeout=None, url_prefix=''):
self._host = host
self._port = port
self._url_prefix = url_prefix
def __init__(self, url='http://localhost:8082/', connect_timeout=None):
    """
    :param url: base URL of the remote scheduler; request paths are
                appended to it. ``http+unix://`` URLs need Unix socket
                support to be available.
    :param connect_timeout: timeout passed to the fetcher for each request.
                            When None it is read from the ``core`` /
                            ``rpc-connect-timeout`` config entry
                            (default 10.0).
    :raises RuntimeError: for an ``http+unix://`` URL when Unix socket
                          support could not be imported.
    """
    # The previous ``assert (condition, message)`` asserted a non-empty
    # tuple, which is always true, so the check could never fail. An
    # explicit raise also keeps working under ``python -O``.
    if url.startswith('http+unix://') and not HAS_UNIX_SOCKETS:
        raise RuntimeError(
            'You need to install requests-unixsocket for Unix socket support.'
        )

    self._url = url
    config = configuration.get_config()

    if connect_timeout is None:
        connect_timeout = config.getfloat('core', 'rpc-connect-timeout', 10.0)
    self._connect_timeout = connect_timeout

    # Prefer a requests session when available; otherwise fall back to the
    # urllib based fetcher.
    if HAS_REQUESTS:
        self._fetcher = RequestsFetcher(requests.Session())
    else:
        self._fetcher = URLLibFetcher()

def _wait(self):
time.sleep(30)

def _fetch(self, url, body, log_exceptions=True, attempts=3):

full_url = 'http://{host}:{port:d}{prefix}{url}'.format(
host=self._host,
port=self._port,
prefix=self._url_prefix,
url=url)
def _fetch(self, url_suffix, body, log_exceptions=True, attempts=3):
full_url = self._url + url_suffix
last_exception = None
attempt = 0
while attempt < attempts:
Expand All @@ -76,24 +121,23 @@ def _fetch(self, url, body, log_exceptions=True, attempts=3):
logger.info("Retrying...")
self._wait() # wait for a bit and retry
try:
response = urlopen(full_url, body, self._connect_timeout)
response = self._fetcher.fetch(full_url, body, self._connect_timeout)
break
except (URLError, socket.timeout) as e:
last_exception = e
except FetcherException as e:
last_exception = e.original_exc
if log_exceptions:
logger.exception("Failed connecting to remote scheduler %r", self._host)
logger.exception("Failed connecting to remote scheduler %r", self._url)
continue
else:
raise RPCError(
"Errors (%d attempts) when connecting to remote scheduler %r" %
(attempts, self._host),
(attempts, self._url),
last_exception
)
return response.read().decode('utf-8')
return response

def _request(self, url, data, log_exceptions=True, attempts=3):
data = {'data': json.dumps(data)}
body = urlencode(data).encode('utf-8')
body = {'data': json.dumps(data)}

page = self._fetch(url, body, log_exceptions, attempts)
result = json.loads(page)
Expand Down
17 changes: 13 additions & 4 deletions luigi/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,19 +240,22 @@ def app(scheduler):
return api_app


def _init_api(scheduler, responder=None, api_port=None, address=None):
def _init_api(scheduler, responder=None, api_port=None, address=None, unix_socket=None):
if responder:
raise Exception('The "responder" argument is no longer supported')
api_app = app(scheduler)
api_sockets = tornado.netutil.bind_sockets(api_port, address=address)
if unix_socket is not None:
api_sockets = [tornado.netutil.bind_unix_socket(unix_socket)]
else:
api_sockets = tornado.netutil.bind_sockets(api_port, address=address)
server = tornado.httpserver.HTTPServer(api_app)
server.add_sockets(api_sockets)

# Return the bound socket names. Useful for connecting client in test scenarios.
return [s.getsockname() for s in api_sockets]


def run(api_port=8082, address=None, scheduler=None, responder=None):
def run(api_port=8082, address=None, unix_socket=None, scheduler=None, responder=None):
"""
Runs one instance of the API server.
"""
Expand All @@ -262,7 +265,13 @@ def run(api_port=8082, address=None, scheduler=None, responder=None):
# load scheduler state
scheduler.load()

_init_api(scheduler, responder, api_port, address)
_init_api(
scheduler=scheduler,
responder=responder,
api_port=api_port,
address=address,
unix_socket=unix_socket,
)

# prune work DAG every 60 seconds
pruner = tornado.ioloop.PeriodicCallback(scheduler.prune, 60000)
Expand Down
4 changes: 2 additions & 2 deletions test/customized_run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ def __init__(self, *args, **kwargs):
def create_local_scheduler(self):
return self.scheduler

def create_remote_scheduler(self, host, port):
return CustomizedRemoteScheduler(host=host, port=port)
def create_remote_scheduler(self, url):
return CustomizedRemoteScheduler(url)

def create_worker(self, scheduler, worker_processes=None, assistant=False):
return self.worker
Expand Down
4 changes: 2 additions & 2 deletions test/db_task_history_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def test_subsecond_timestamp(self):
self.run_task(task)

task_record = six.advance_iterator(self.history.find_all_by_name('DummyTask'))
print (task_record.events)
print(task_record.events)
self.assertEqual(task_record.events[0].event_name, DONE)

def test_utc_conversion(self):
Expand All @@ -111,7 +111,7 @@ def test_utc_conversion(self):
task_record = six.advance_iterator(self.history.find_all_by_name('DummyTask'))
last_event = task_record.events[0]
try:
print (from_utc(str(last_event.ts)))
print(from_utc(str(last_event.ts)))
except ValueError:
self.fail("Failed to convert timestamp {} to UTC".format(last_event.ts))

Expand Down
2 changes: 1 addition & 1 deletion test/rpc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def get_app(self):

def setUp(self):
super(RPCTest, self).setUp()
self.sch = luigi.rpc.RemoteScheduler(port=self.get_http_port())
self.sch = luigi.rpc.RemoteScheduler(self.get_url(''))
self.sch._wait = lambda: None

@skipOnTravis('https://travis-ci.org/spotify/luigi/jobs/72276513')
Expand Down
Loading

0 comments on commit ecc591b

Please sign in to comment.