diff --git a/src/command.py b/src/command.py index c828c5b..cfc1e15 100644 --- a/src/command.py +++ b/src/command.py @@ -42,26 +42,16 @@ def run_python_over_tor(queue, circ_id, socks_port): def closure(func, *args): """ Route the given Python function's network traffic over Tor. - - We temporarily monkey-patch socket.socket using our torsocks module and - reset it, once the function returns. + We temporarily monkey-patch socket.socket using our torsocks + module, and reset it once the function returns. """ - - torsocks.set_default_proxy("127.0.0.1", socks_port) - torsocks.queue = queue - torsocks.circ_id = circ_id - - orig_socket = socket.socket - socket.socket = torsocks.torsocket - try: - func(*args) - except error.SOCKSv5Error as err: + with torsocks.MonkeyPatchedSocket(queue, circ_id, socks_port): + func(*args) + except (error.SOCKSv5Error, socket.error) as err: logger.info(err) return - socket.socket = orig_socket - return closure @@ -71,14 +61,13 @@ class Command(object): Provide an abstraction for a shell command which is to be run. """ - def __init__(self, queue, circ_id, origsock, socks_port): + def __init__(self, queue, circ_id, socks_port): self.process = None self.stdout = None self.stderr = None self.output_callback = None self.queue = queue - self.origsocket = origsock self.circ_id = circ_id self.socks_port = socks_port @@ -122,14 +111,7 @@ def invoke_process(self, command): port = util.extract_pattern(line, pattern) if port: - - # socket.socket is probably monkey-patched. We need, - # however, the original implementation. - - tmpsock = socket.socket - socket.socket = self.origsocket self.queue.put([self.circ_id, ("127.0.0.1", int(port))]) - socket.socket = tmpsock keep_reading = self.output_callback(line, self.process.kill) diff --git a/src/eventhandler.py b/src/eventhandler.py index 7828309..7864107 100644 --- a/src/eventhandler.py +++ b/src/eventhandler.py @@ -121,7 +121,7 @@ def _attach(self, stream_id=None, circuit_id=None): logger.warning("Failed to attach stream because: %s" % err) -def module_closure(queue, module, circ_id, *module_args): +def module_closure(queue, module, circ_id, *module_args, **module_kwargs): """ Return function that runs the module and then informs event handler. """ @@ -135,7 +135,7 @@ def func(): """ try: - module(*module_args) + module(*module_args, **module_kwargs) logger.debug("Informing event handler that module finished.") queue.put((circ_id, None)) @@ -155,7 +155,7 @@ class EventHandler(object): new streams unattached. """ - def __init__(self, controller, module, socks_port, stats): + def __init__(self, controller, module, socks_port, stats, exit_destinations): self.stats = stats self.controller = controller @@ -164,6 +164,9 @@ def __init__(self, controller, module, socks_port, stats): self.manager = multiprocessing.Manager() self.queue = self.manager.Queue() self.socks_port = socks_port + self.exit_destinations = exit_destinations + self.check_finished_lock = threading.Lock() + self.already_finished = False queue_thread = threading.Thread(target=self.queue_reader) queue_thread.daemon = False @@ -212,37 +215,43 @@ def check_finished(self): Check if the scan is finished and if it is, shut down exitmap. """ - # Did all circuits either build or fail? + # This is called from both the queue reader thread and the + # main thread, but (if it detects completion) does things that + # must only happen once. 
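+        # Hence the lock and the already_finished flag below: whichever
+        # caller detects completion second just calls sys.exit(0)
+        # instead of repeating the teardown.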
+ with self.check_finished_lock: + if self.already_finished: + sys.exit(0) - circs_done = ((self.stats.failed_circuits + - self.stats.successful_circuits) == - self.stats.total_circuits) + # Did all circuits either build or fail? + circs_done = ((self.stats.failed_circuits + + self.stats.successful_circuits) == + self.stats.total_circuits) - # Was every built circuit attached to a stream? + # Was every built circuit attached to a stream? + streams_done = (self.stats.finished_streams >= + (self.stats.successful_circuits - + self.stats.failed_circuits)) - streams_done = (self.stats.finished_streams >= - (self.stats.successful_circuits - - self.stats.failed_circuits)) + logger.debug("failedCircs=%d, builtCircs=%d, totalCircs=%d, " + "finishedStreams=%d" % ( + self.stats.failed_circuits, + self.stats.successful_circuits, + self.stats.total_circuits, + self.stats.finished_streams)) - logger.debug("failedCircs=%d, builtCircs=%d, totalCircs=%d, " - "finishedStreams=%d" % ( - self.stats.failed_circuits, - self.stats.successful_circuits, - self.stats.total_circuits, - self.stats.finished_streams)) + if circs_done and streams_done: + self.already_finished = True - if circs_done and streams_done: + for proc in multiprocessing.active_children(): + logger.debug("Terminating remaining PID %d." % proc.pid) + proc.terminate() - for proc in multiprocessing.active_children(): - logger.debug("Terminating remaining PID %d." % proc.pid) - proc.terminate() + if hasattr(self.module, "teardown"): + logger.debug("Calling module's teardown() function.") + self.module.teardown() - if hasattr(self.module, "teardown"): - logger.debug("Calling module's teardown() function.") - self.module.teardown() - - logger.info(self.stats) - sys.exit(0) + logger.info(self.stats) + sys.exit(0) def new_circuit(self, circ_event): """ @@ -262,7 +271,6 @@ def new_circuit(self, circ_event): run_cmd_over_tor = command.Command(self.queue, circ_event.id, - socket.socket, self.socks_port) exit_desc = get_relay_desc(self.controller, exit_fpr) @@ -275,7 +283,8 @@ def new_circuit(self, circ_event): command.run_python_over_tor(self.queue, circ_event.id, self.socks_port), - run_cmd_over_tor) + run_cmd_over_tor, + destinations=self.exit_destinations[exit_fpr]) proc = multiprocessing.Process(target=module) proc.daemon = True diff --git a/src/exitmap.py b/src/exitmap.py index 48eee4b..f6a3782 100644 --- a/src/exitmap.py +++ b/src/exitmap.py @@ -262,6 +262,22 @@ def main(): logger.error("Failed to run because : %s" % err) return 0 +def lookup_destinations(module): + """ + Determine the set of destinations that the module might like to scan. + This removes redundancies and reduces all hostnames to IP addresses. + """ + destinations = set() + addrs = {} + if hasattr(module, 'destinations'): + raw_destinations = module.destinations + if raw_destinations is not None: + for (host, port) in raw_destinations: + if host not in addrs: + addrs[host] = socket.gethostbyname(host) + destinations.add((addrs[host], port)) + + return destinations def select_exits(args, module): """ @@ -273,50 +289,37 @@ def select_exits(args, module): """ before = datetime.datetime.now() - hosts = [] - - if module.destinations is not None: - hosts = [(socket.gethostbyname(host), port) for - (host, port) in module.destinations] + destinations = lookup_destinations(module) if args.exit: # '-e' was used to specify a single exit relay. 
- - exit_relays = [args.exit] - total = len(exit_relays) + requested_exits = [args.exit] elif args.exit_file: - # '-E' was used to specify a file containing exit relays - + # '-E' was used to specify a file containing exit relays. try: - exit_relays = [line.strip() for line in open(args.exit_file)] - total = len(exit_relays) + requested_exits = [line.strip() for line in open(args.exit_file)] + except OSError as err: + logger.error("Could not read %s: %s", args.exit_file, + err.strerror) + sys.exit(1) except Exception as err: - logger.error("Could not read file %s", args.exit_file) + logger.error("Could not read %s: %s", args.exit_file, err) sys.exit(1) else: - good_exits = False if (args.all_exits or args.bad_exits) else True - total, exit_relays = relayselector.get_exits(args.tor_dir, - country_code=args.country, - bad_exit=args.bad_exits, - good_exit=good_exits, - hosts=hosts) + requested_exits = None + + exit_destinations = relayselector.get_exits( + args.tor_dir, + good_exit = args.all_exits or (not args.bad_exits), + bad_exit = args.all_exits or args.bad_exits, + country_code = args.country, + requested_exits = requested_exits, + destinations = destinations) logger.debug("Successfully selected exit relays after %s." % str(datetime.datetime.now() - before)) - pretty_hosts = ["%s:%d" % (host, port) for host, port in hosts] - logger.info("%d%s exit relays out of all %s exit relays allow traffic " - "to: %s" % (len(exit_relays), - " %s" % args.country if args.country else "", - total, - ", ".join(pretty_hosts))) - - assert isinstance(exit_relays, list) - - random.shuffle(exit_relays) - - return exit_relays - + return exit_destinations def run_module(module_name, args, controller, socks_port, stats): """ @@ -338,7 +341,10 @@ def run_module(module_name, args, controller, socks_port, stats): logger.debug("Calling module's setup() function.") module.setup() - exit_relays = select_exits(args, module) + exit_destinations = select_exits(args, module) + + exit_relays = list(exit_destinations.keys()) + random.shuffle(exit_relays) count = len(exit_relays) stats.total_circuits += count @@ -347,7 +353,9 @@ def run_module(module_name, args, controller, socks_port, stats): raise error.ExitSelectionError("Exit selection yielded %d exits " "but need at least one." % count) - handler = EventHandler(controller, module, socks_port, stats) + handler = EventHandler(controller, module, socks_port, stats, + exit_destinations=exit_destinations) + controller.add_event_listener(handler.new_event, EventType.CIRC, EventType.STREAM) @@ -355,7 +363,7 @@ def run_module(module_name, args, controller, socks_port, stats): logger.info("Scan is estimated to take around %s." % datetime.timedelta(seconds=duration)) - logger.debug("Beginning to trigger %d circuit creation(s)." % count) + logger.info("Beginning to trigger %d circuit creation(s)." % count) iter_exit_relays(exit_relays, controller, stats, args) @@ -370,8 +378,6 @@ def iter_exit_relays(exit_relays, controller, stats, args): fingerprints = relayselector.get_fingerprints(cached_consensus_path) count = len(exit_relays) - logger.info("Beginning to trigger circuit creations.") - # Start building a circuit for every exit relay we got. for i, exit_relay in enumerate(exit_relays): diff --git a/src/modules/checktest.py b/src/modules/checktest.py index 4532460..16d92de 100644 --- a/src/modules/checktest.py +++ b/src/modules/checktest.py @@ -73,7 +73,7 @@ def fetch_page(exit_desc): logger.debug("Exit relay %s passed the check test." 
                      % url)


-def probe(exit_desc, run_python_over_tor, run_cmd_over_tor):
+def probe(exit_desc, run_python_over_tor, run_cmd_over_tor, **kwargs):
     """
     Probe the given exit relay and look for check.tp.o false negatives.
     """
diff --git a/src/modules/cloudflared.py b/src/modules/cloudflared.py
index 626d48e..f91ffe8 100644
--- a/src/modules/cloudflared.py
+++ b/src/modules/cloudflared.py
@@ -92,7 +92,7 @@ def is_cloudflared(exit_fpr):
         logger.info("Exit %s does not see a CAPTCHA." % exit_url)


-def probe(exit_desc, run_python_over_tor, run_cmd_over_tor):
+def probe(exit_desc, run_python_over_tor, run_cmd_over_tor, **kwargs):
     """
     Check if exit relay sees a CloudFlare CAPTCHA.
     """
diff --git a/src/modules/dnspoison.py b/src/modules/dnspoison.py
index a8c21bb..f180f90 100644
--- a/src/modules/dnspoison.py
+++ b/src/modules/dnspoison.py
@@ -93,7 +93,7 @@ def resolve(exit_desc, domain, whitelist):
                      (domain, exit))


-def probe(exit_desc, run_python_over_tor, run_cmd_over_tor):
+def probe(exit_desc, run_python_over_tor, run_cmd_over_tor, **kwargs):
     """
     Probe the given exit relay and check if all domains resolve as expected.
     """
diff --git a/src/modules/dnssec.py b/src/modules/dnssec.py
index 05b9686..69eb953 100644
--- a/src/modules/dnssec.py
+++ b/src/modules/dnssec.py
@@ -67,7 +67,7 @@ def test_dnssec(exit_fpr):
         logger.critical("%s resolved domain to %s" % (exit_url, ipv4))


-def probe(exit_desc, run_python_over_tor, run_cmd_over_tor):
+def probe(exit_desc, run_python_over_tor, run_cmd_over_tor, **kwargs):
     """
     Test if exit relay can resolve broken domain.
     """
diff --git a/src/modules/patchingCheck.py b/src/modules/patchingCheck.py
index faa16e1..367b5d8 100644
--- a/src/modules/patchingCheck.py
+++ b/src/modules/patchingCheck.py
@@ -206,7 +206,7 @@ def run_check(exit_desc):
         os.remove(tmp_file)


-def probe(exit_desc, run_python_over_tor, run_cmd_over_tor):
+def probe(exit_desc, run_python_over_tor, run_cmd_over_tor, **kwargs):
     """
     Probe the given exit relay and look for modified binaries.
     """
diff --git a/src/modules/rtt.py b/src/modules/rtt.py
new file mode 100644
index 0000000..e8791f6
--- /dev/null
+++ b/src/modules/rtt.py
@@ -0,0 +1,304 @@
+#!/usr/bin/env python2
+
+# Copyright 2013-2016 Philipp Winter
+# Copyright 2016 Zack Weinberg
+#
+# This file is part of exitmap.
+#
+# exitmap is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# exitmap is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with exitmap.  If not, see <http://www.gnu.org/licenses/>.
+
+"""
+Module to measure round-trip times through an exit to various
+destinations.  Each destination will receive ten TCP connections from
+each scanned exit, no faster than one connection every 250 ms (see
+CONNECTION_SPACING below).  The module doesn't care whether it gets a
+SYN/ACK or a RST in response -- either way, the round-trip time is
+recorded and the connection is dropped.
+
+Connections are attempted to one of port 53, 22, 443 or 80, depending
+on what's allowed by the exit's policy.
+
+Until modules can take command-line arguments, the destinations should
+be specified in a text file named "rtt-destinations.txt", one IP
+address per line.
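+For example, "rtt-destinations.txt" might look like this (hypothetical,
+documentation-range addresses):
+
+    # comments and blank lines are ignored
+    192.0.2.10
+    198.51.100.7
+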
+(You _may_ use hostnames, but if you do, they will
+be resolved directly, not via Tor.)
+"""
+
+## Configuration parameters:
+# The set of ports that we consider connecting to.
+PREFERRED_PORT_ORDER = (53, 22, 443, 80)
+
+# The total number of connections to make to each host.
+CONNECTIONS_PER_HOST = 10
+
+# The number of hosts to connect to in parallel.  Note that we will
+# _not_ connect to any one host more than once at a time.
+PARALLEL_CONNECTIONS = 4
+
+# The delay between successive connections (seconds).
+CONNECTION_SPACING = 0.25
+
+# The per-connection timeout (seconds).
+CONNECTION_TIMEOUT = 10.0
+
+import sys
+import os
+
+# We don't _need_ the top-level exitmap module, but this is the most
+# reliable way to figure out whether we need to add the directory with
+# the utility modules that we _do_ need to sys.path.
+try:
+    import exitmap
+except ImportError:
+    current_path = os.path.dirname(__file__)
+    src_path = os.path.abspath(os.path.join(current_path, ".."))
+    sys.path.insert(0, src_path)
+    import exitmap
+
+import csv
+import errno
+import random
+import socket
+
+try:
+    from time import monotonic as tick
+except ImportError:
+    # FIXME: Maybe use ctypes to get at clock_gettime(CLOCK_MONOTONIC)?
+    from time import time as tick
+
+try:
+    import selectors
+except ImportError:
+    import selectors34 as selectors
+
+import log
+logger = log.get_logger()
+
+
+def progress(total, pending, complete):
+    logger.info("{:>6}/{:>6} complete, {} pending"
+                .format(complete, total, pending))
+
+import util
+
+
+def perform_probes(addresses, spacing, parallel, timeout, wr):
+    """Make a TCP connection to each of the ADDRESSES, in order, and
+    measure the time for connect(2) to either succeed or fail -- we
+    don't care which.  Each element of the iterable ADDRESSES should
+    be an AF_INET address 2-tuple (i.e. ('a.b.c.d', n)).  Successive
+    connections will be no closer to each other in time than SPACING
+    floating-point seconds.  No more than PARALLEL concurrent
+    connections will occur at any one time.  Sockets that have neither
+    succeeded nor failed to connect after TIMEOUT floating-point
+    seconds will be treated as having failed.  No data is transmitted;
+    each socket is closed immediately after the connection resolves.
+
+    The results are written to the csv.writer object WR; each row of the
+    file will be <host>,<port>,<elapsed>.
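+    For example, a row of "192.0.2.10,443,0.137" (hypothetical values)
+    would mean that the connection to 192.0.2.10:443 resolved, one way
+    or the other, after 137 milliseconds.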
+ """ + + if timeout <= 0: + raise ValueError("timeout must be positive") + if spacing <= 0: + raise ValueError("spacing must be positive") + if parallel < 1: + raise ValueError("parallel must be at least 1") + + sel = selectors.DefaultSelector() + EVENT_READ = selectors.EVENT_READ + AF_INET = socket.AF_INET + SOCK_STREAM = socket.SOCK_STREAM + + EINPROGRESS = errno.EINPROGRESS + CONN_RESOLVED = (0, + errno.ECONNREFUSED, + errno.EHOSTUNREACH, + errno.ENETUNREACH, + errno.ETIMEDOUT, + errno.ECONNRESET) + + pending = set() + addresses.reverse() + last_connection = 0 + last_progress = 0 + total = len(addresses) + complete = 0 + change = False + + try: + while pending or addresses: + now = tick() + if change or now - last_progress > 10: + progress(total, len(pending), complete) + last_progress = now + change = False + + if (len(pending) < parallel and addresses + and now - last_connection >= spacing): + + addr = addresses.pop() + sock = socket.socket(AF_INET, SOCK_STREAM) + sock.setblocking(False) + + last_connection = tick() + err = sock.connect_ex(addr) + logger.debug("Socket %d connecting to %r returned %d/%s", + sock.fileno(), addr, err, os.strerror(err)) + if err == EINPROGRESS: + # This is the expected case: the connection attempt is + # in progress and we must wait for results. + pending.add(sel.register(sock, EVENT_READ, + (addr, last_connection))) + change = True + + elif err in CONN_RESOLVED: + # The connection attempt resolved before connect() + # returned. + after = tick() + sock.close() + wr.writerow((addr[0], addr[1], after - now)) + complete += 1 + change = True + + else: + # Something dire has happened and we probably + # can't continue (for instance, there's no local + # network connection). + exc = socket.error(err, os.strerror(err)) + exc.filename = '%s:%d' % addr + raise exc + + events = sel.select(spacing) + after = tick() + # We don't care whether each connection succeeded or failed. + for key, _ in events: + addr, before = key.data + sock = key.fileobj + logger.debug("Socket %d connecting to %r resolved", + sock.fileno(), addr) + + sel.unregister(sock) + sock.close() + pending.remove(key) + wr.writerow((addr[0], addr[1], after - before)) + complete += 1 + change = True + + # Check for timeouts. + for key in list(pending): + addr, before = key.data + if after - before >= timeout: + sock = key.fileobj + logger.debug("Socket %d connecting to %r timed out", + sock.fileno(), addr) + sel.unregister(sock) + sock.close() + pending.remove(key) + wr.writerow((addr[0], addr[1], after - before)) + complete += 1 + change = True + + #end while + progress(total, len(pending), complete) + + finally: + for key in pending: + sel.unregister(key.fileobj) + key.fileobj.close() + sel.close() + + +def choose_probe_order(dests): + """Choose a randomized probe order for the destinations DESTS, which is + a set of (host, port) pairs. 
The return value is a list acceptable + as the ADDRESSES argument to perform_probes.""" + + hosts = {} + for h,p in dests: + if h not in hosts: hosts[h] = set() + hosts[h].add(p) + + remaining = {} + last_appearance = {} + full_address = {} + for host, usable_ports in hosts.iteritems(): + for p in PREFERRED_PORT_ORDER: + if p in usable_ports: + full_address[host] = (host, p) + remaining[host] = CONNECTIONS_PER_HOST + last_appearance[host] = -1 + + rv = [] + deadcycles = 0 + while remaining: + ks = remaining.keys() + x = random.choice(ks) + last = last_appearance[x] + if last == -1 or (len(rv) - last) >= (len(ks) // 4): + last_appearance[x] = len(rv) + rv.append(full_address[x]) + remaining[x] -= 1 + if not remaining[x]: + del remaining[x] + deadcycles = 0 + else: + deadcycles += 1 + if deadcycles == 10: + raise RuntimeError("choose_probe_order: 10 dead cycles\n" + "remaining: %r\n" + "last_appearance: %r\n" + % (remaining, last_appearance)) + return rv + + +def probe(exit_desc, run_python_over_tor, run_cmd_over_tor, + destinations, **kwargs): + """ + Probe the given exit relay. + """ + addresses = choose_probe_order(destinations) + + try: + os.makedirs(util.analysis_dir) + except OSError as err: + if err.errno != errno.EEXIST: + raise + + with open(os.path.join(util.analysis_dir, + exit_desc.fingerprint + ".csv"), "wt") as f: + wr = csv.writer(f, quoting=csv.QUOTE_MINIMAL, lineterminator='\n') + wr.writerow(("host","port","elapsed")) + + run_python_over_tor(perform_probes, + addresses, + CONNECTION_SPACING, + PARALLEL_CONNECTIONS, + CONNECTION_TIMEOUT, + wr) + +# exitmap needs this variable to figure out which relays can exit to the given +# destination(s). + +destinations = None +def setup(): + ds = set() + with open("rtt-destinations.txt") as f: + for line in f: + line = line.strip() + if not line or line[0] == '#': continue + ipaddr = socket.getaddrinfo( + line, 80, socket.AF_INET, socket.SOCK_STREAM, 0, 0)[0][4][0] + + for p in PREFERRED_PORT_ORDER: + ds.add((ipaddr, p)) + + global destinations + destinations = sorted(ds) diff --git a/src/modules/testfds.py b/src/modules/testfds.py index a710776..7a242fb 100644 --- a/src/modules/testfds.py +++ b/src/modules/testfds.py @@ -70,7 +70,7 @@ def fetch_page(exit_desc): logger.debug("Exit relay %s worked fine." % exit_url) -def probe(exit_desc, run_python_over_tor, run_cmd_over_tor): +def probe(exit_desc, run_python_over_tor, run_cmd_over_tor, **kwargs): """ Attempts to fetch a small web page and yells if this fails. """ diff --git a/src/relayselector.py b/src/relayselector.py index b06a3d7..3dca1a4 100755 --- a/src/relayselector.py +++ b/src/relayselector.py @@ -82,125 +82,201 @@ def get_fingerprints(cached_consensus_path, exclude=[]): return fingerprints - -def get_exits(data_dir, country_code=None, bad_exit=False, good_exit=False, - version=None, nickname=None, address=None, hosts=[]): - """ - Extract exit relays with given attributes from consensus. - - Attempts to get the consensus from the provided data directory and extracts - all relays with the given attributes. - """ - - assert not (bad_exit and good_exit) - - cached_consensus = {} - have_exit_policy = {} - have_exit_flag = {} - - cached_consensus_path = os.path.join(data_dir, "cached-consensus") - cached_descriptors_path = os.path.join(data_dir, "cached-descriptors") - - # First, read the file "cached_descriptors" in order to get the full exit - # policy of all relays instead of just the summary which might be - # insufficient. 
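+# A minimal usage sketch for the helpers below (the path and the
+# fingerprint are illustrative):
+#
+#     have_exit_policy = get_exit_policies("/var/lib/tor/cached-descriptors")
+#     desc = have_exit_policy[fingerprint]
+#     if desc.exit_policy.can_exit_to("192.0.2.10", 443):
+#         ...  # this relay's full policy allows that destination
+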
+def get_exit_policies(cached_descriptors_path):
+    """Read all relays' full exit policies from "cached_descriptors"."""
     try:
+        have_exit_policy = {}
         # We don't validate to work around the following issue:
         #
-        for desc in stem.descriptor.parse_file(cached_descriptors_path, validate=False):
+        for desc in stem.descriptor.parse_file(cached_descriptors_path,
+                                               validate=False):
             if desc.exit_policy.is_exiting_allowed():
                 have_exit_policy[desc.fingerprint] = desc
+
+        return have_exit_policy
+
     except IOError as err:
         logger.critical("File \"%s\" could not be read: %s" %
                         (cached_descriptors_path, err))
         sys.exit(1)
 
-    # Now, also read the file "cached_consensus" to see which relays got the
-    # "Exit" flag from the directory authorities.
+
+def get_cached_consensus(cached_consensus_path):
+    """Read relays' summarized descriptors from "cached_consensus"."""
     try:
+        cached_consensus = {}
         for desc in stem.descriptor.parse_file(cached_consensus_path):
             cached_consensus[desc.fingerprint] = desc
-            if stem.Flag.EXIT in desc.flags:
-                have_exit_flag[desc.fingerprint] = desc
+
+        return cached_consensus
+
     except IOError as err:
-        logger.critical("File \"%s\" could not be read: %s" %
-                        (cached_descriptors_path, err))
+        logger.critical("File \"%s\" could not be read: %s" %
+                        (cached_consensus_path, err))
         sys.exit(1)
 
-    # Drop all exit relays for which we have a descriptor but which did not
-    # make it into the consensus.
-
-    have_exit_policy = {fpr: desc for fpr, desc in have_exit_policy.iteritems()
-                        if fpr in cached_consensus}
-    exit_candidates = list(have_exit_policy.values())
-
-    set_diff = set(have_exit_policy.keys()) - set(have_exit_flag.keys())
-    logger.info("%d relays have non-empty exit policy but no exit flag." %
-                len(set_diff))
-
-    if hosts:
-        def can_exit_to(desc):
-            for (ip_addr, port) in hosts:
-
-                # Use the full exit policy for the given descriptor.
-
-                desc = have_exit_policy.get(desc.fingerprint, None)
-                assert desc
-                if not desc.exit_policy.can_exit_to(ip_addr, port):
-                    return False
-
-            return True
-
-        exit_candidates = filter(can_exit_to, exit_candidates)
+
+def get_exits(data_dir,
+              good_exit=True, bad_exit=False,
+              version=None, nickname=None, address=None, country_code=None,
+              requested_exits=None, destinations=None):
+    """Load the Tor network consensus from DATA_DIR, and extract all exit
+    relays that have the desired set of attributes.  Specifically:
+
+    - requested_exits: If not None, must be a list of fingerprints,
+      and only those relays will be included in the results.
+
+    - country_code, version, nickname, address:
+      If not None, only relays with the specified attributes
+      will be included in the results.
+
+    - bad_exit, good_exit: If True, the respective type of exit will
+      be included.  At least one should be True, or else the results
+      will be empty.
+
+    These combine as follows:
+
+        exit.fingerprint IN requested_exits
+        AND exit.country_code == country_code
+        AND exit.version == version
+        AND nickname IN exit.nickname
+        AND address IN exit.address
+        AND (   (bad_exit AND exit.is_bad_exit)
+             OR (good_exit AND NOT exit.is_bad_exit))
+
+    In all cases, the criterion is skipped if the argument is None.
+
+    Finally, 'destinations' is considered.  If this is None, all
+    results from the above filter expression are returned.  Otherwise,
+    'destinations' must be a set of (host, port) pairs, and only exits
+    that can connect to at least one of these destinations will be
+    included in the results.
+
+    Returns a dictionary whose keys are the selected relays' fingerprints.
+    The value for each fingerprint is a set of (host, port) pairs that
+    that exit is willing to connect to; this is always a subset of the
+    input 'destinations' set.
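+    For example, a result of
+
+        { '<fingerprint>': set([('192.0.2.10', 443)]) }
+
+    would mean that exactly one selected exit allows connections to the
+    (hypothetical) destination 192.0.2.10:443.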
(If 'destinations' was None, each value + is a pseudo-set object for which '(host, port) in s' always + returns True.) + """ - if address: - exit_candidates = filter(lambda desc: address in desc.address, - exit_candidates) - if nickname: - exit_candidates = filter(lambda desc: nickname in desc.nickname, - exit_candidates) + cached_consensus_path = os.path.join(data_dir, "cached-consensus") + cached_descriptors_path = os.path.join(data_dir, "cached-descriptors") - if bad_exit: - exit_candidates = filter(lambda desc: stem.Flag.BADEXIT in - cached_consensus[desc.fingerprint].flags, - exit_candidates) + cached_consensus = get_cached_consensus(cached_consensus_path) + have_exit_policy = get_exit_policies(cached_descriptors_path) + + # Drop all exit relays which have a descriptor, but either did not + # make it into the consensus at all, or are not marked as exits there. + class StubDesc(object): + def __init__(self): + self.flags = frozenset() + stub_desc = StubDesc() + + exit_candidates = [ + desc + for fpr, desc in have_exit_policy.iteritems() + if stem.Flag.EXIT in cached_consensus.get(fpr, stub_desc).flags + ] + + logger.info("%d relays have non-empty exit policy but no exit flag.", + len(have_exit_policy) - len(exit_candidates)) + if not exit_candidates: + logger.warning("No relays have both a non-empty exit policy and an " + "exit flag. This probably means the cached network " + "consensus is invalid.") + return {} + + if bad_exit and good_exit: + pass # All exits are either bad or good. + elif bad_exit: + exit_candidates = [ + desc for desc in exit_candidates + if stem.Flag.BADEXIT in cached_consensus[desc.fingerprint].flags + ] + if not exit_candidates: + logger.warning("There are no bad exits in the current consensus.") + return {} elif good_exit: - exit_candidates = filter(lambda desc: stem.Flag.BADEXIT not in - cached_consensus[desc.fingerprint].flags, - exit_candidates) - - if version: - exit_candidates = filter(lambda desc: str(desc.tor_version) == version, - exit_candidates) + exit_candidates = [ + desc for desc in exit_candidates + if stem.Flag.BADEXIT not in cached_consensus[desc.fingerprint].flags + ] + if not exit_candidates: + logger.warning("There are no good exits in the current consensus.") + return {} + else: + # This was probably a programming error. + logger.warning("get_exits() called with bad_exits=False and " + "good_exits=False; this always returns zero exits") + return {} + + # Filter conditions are checked from cheapest to most expensive. + if address or nickname or version or requested_exits: + exit_candidates = [ + desc for desc in exit_candidates + if ((not address or address in desc.address) and + (not nickname or nickname in desc.nickname) and + (not version or version == str(desc.tor_version)) and + (not requested_exits or desc.fingerprint in requested_exits)) + ] + if not exit_candidates: + logger.warning("No exit relays meet basic filter conditions.") + return {} if country_code: - - # Get fingerprint of all relays in desired country. 
- - relay_fprs = util.get_relays_in_country(country_code) - - all_exit_fprs = [desc.fingerprint for desc in exit_candidates] - exit_fprs = filter(lambda fpr: fpr in all_exit_fprs, relay_fprs) - return len(exit_fprs), exit_fprs - - return (len(have_exit_policy), - [desc.fingerprint for desc in exit_candidates]) + relay_fprs = frozenset(util.get_relays_in_country(country_code)) + exit_candidates = [ + desc for desc in exit_candidates + if desc.fingerprint in relay_fprs + ] + if not exit_candidates: + logger.warning("No exit relays meet country-code filter condition.") + return {} + + if not destinations: + class UniversalSet(object): + """A universal set contains everything, but cannot be enumerated. + + If the caller of get_exits does not specify destinations, + its return value maps all fingerprints to a universal set, + so that it can still fulfill the contract of returning a + dictionary of the form { fingerprint : set(...) }. + """ + def __nonzero__(self): return True + def __contains__(self, obj): return True + # __len__ is obliged to return a positive integer. + def __len__(self): return sys.maxsize + us = UniversalSet() + exit_destinations = { + desc.fingerprint: us for desc in exit_candidates } + else: + exit_destinations = {} + for desc in exit_candidates: + policy = have_exit_policy[desc.fingerprint].exit_policy + ok_dests = frozenset(d for d in destinations + if policy.can_exit_to(*d)) + if ok_dests: + exit_destinations[desc.fingerprint] = ok_dests + + logger.info("%d out of %d exit relays meet all filter conditions." + % (len(exit_destinations), len(have_exit_policy))) + return exit_destinations def main(): args = parse_cmd_args() - _, exits = get_exits(args.data_dir, args.countrycode, args.badexit, - args.goodexit, args.version, args.nickname, - args.address) - for e in exits: + exits = get_exits(args.data_dir, + country_code = args.countrycode, + bad_exit = args.badexit, + good_exit = args.goodexit, + version = args.version, + nickname = args.nickname, + address = args.address) + for e in exits.keys(): print("https://atlas.torproject.org/#details/%s" % e) diff --git a/src/selectors34.py b/src/selectors34.py new file mode 100644 index 0000000..1f9c9cb --- /dev/null +++ b/src/selectors34.py @@ -0,0 +1,708 @@ +# -*- encoding: utf-8 -*- +"""Selectors module. + +This module allows high-level and efficient I/O multiplexing, built upon the +`select` module primitives. 
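+
+A typical use, sketched (sock is some already-created socket):
+
+    sel = DefaultSelector()
+    sel.register(sock, EVENT_READ, data=None)
+    for key, events in sel.select(timeout=1.0):
+        ...        # key.fileobj is ready; key.data carries context
+    sel.unregister(sock)
+    sel.close()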
+ +Python 2 backport by Charles-François Natali and Victor Stinner: +https://pypi.python.org/pypi/selectors34 + +""" + +from abc import ABCMeta, abstractmethod +from collections import namedtuple, Mapping +import math +import select +import sys + +import six + +# compatibility code +PY33 = (sys.version_info >= (3, 3)) + + +def _wrap_error(exc, mapping, key): + if key not in mapping: + return + new_err_cls = mapping[key] + new_err = new_err_cls(*exc.args) + + # raise a new exception with the original traceback + if hasattr(exc, '__traceback__'): + traceback = exc.__traceback__ + else: + traceback = sys.exc_info()[2] + six.reraise(new_err_cls, new_err, traceback) + + +if PY33: + import builtins + + BlockingIOError = builtins.BlockingIOError + BrokenPipeError = builtins.BrokenPipeError + ChildProcessError = builtins.ChildProcessError + ConnectionRefusedError = builtins.ConnectionRefusedError + ConnectionResetError = builtins.ConnectionResetError + InterruptedError = builtins.InterruptedError + ConnectionAbortedError = builtins.ConnectionAbortedError + PermissionError = builtins.PermissionError + FileNotFoundError = builtins.FileNotFoundError + ProcessLookupError = builtins.ProcessLookupError + + def wrap_error(func, *args, **kw): + return func(*args, **kw) +else: + import errno + import select + import socket + + class BlockingIOError(OSError): + pass + + class BrokenPipeError(OSError): + pass + + class ChildProcessError(OSError): + pass + + class ConnectionRefusedError(OSError): + pass + + class InterruptedError(OSError): + pass + + class ConnectionResetError(OSError): + pass + + class ConnectionAbortedError(OSError): + pass + + class PermissionError(OSError): + pass + + class FileNotFoundError(OSError): + pass + + class ProcessLookupError(OSError): + pass + + _MAP_ERRNO = { + errno.EACCES: PermissionError, + errno.EAGAIN: BlockingIOError, + errno.EALREADY: BlockingIOError, + errno.ECHILD: ChildProcessError, + errno.ECONNABORTED: ConnectionAbortedError, + errno.ECONNREFUSED: ConnectionRefusedError, + errno.ECONNRESET: ConnectionResetError, + errno.EINPROGRESS: BlockingIOError, + errno.EINTR: InterruptedError, + errno.ENOENT: FileNotFoundError, + errno.EPERM: PermissionError, + errno.EPIPE: BrokenPipeError, + errno.ESHUTDOWN: BrokenPipeError, + errno.EWOULDBLOCK: BlockingIOError, + errno.ESRCH: ProcessLookupError, + } + + def wrap_error(func, *args, **kw): + """ + Wrap socket.error, IOError, OSError, select.error to raise new specialized + exceptions of Python 3.3 like InterruptedError (PEP 3151). + """ + try: + return func(*args, **kw) + except (socket.error, IOError, OSError) as exc: + if hasattr(exc, 'winerror'): + _wrap_error(exc, _MAP_ERRNO, exc.winerror) + # _MAP_ERRNO does not contain all Windows errors. + # For some errors like "file not found", exc.errno should + # be used (ex: ENOENT). + _wrap_error(exc, _MAP_ERRNO, exc.errno) + raise + except select.error as exc: + if exc.args: + _wrap_error(exc, _MAP_ERRNO, exc.args[0]) + raise + +# generic events, that must be mapped to implementation-specific ones +EVENT_READ = (1 << 0) +EVENT_WRITE = (1 << 1) + + +def _fileobj_to_fd(fileobj): + """Return a file descriptor from a file object. 
+ + Parameters: + fileobj -- file object or file descriptor + + Returns: + corresponding file descriptor + + Raises: + ValueError if the object is invalid + """ + if isinstance(fileobj, six.integer_types): + fd = fileobj + else: + try: + fd = int(fileobj.fileno()) + except (AttributeError, TypeError, ValueError): + raise ValueError("Invalid file object: " + "{0!r}".format(fileobj)) + if fd < 0: + raise ValueError("Invalid file descriptor: {0}".format(fd)) + return fd + + +SelectorKey = namedtuple('SelectorKey', ['fileobj', 'fd', 'events', 'data']) +"""Object used to associate a file object to its backing file descriptor, +selected event mask and attached data.""" + + +class _SelectorMapping(Mapping): + """Mapping of file objects to selector keys.""" + + def __init__(self, selector): + self._selector = selector + + def __len__(self): + return len(self._selector._fd_to_key) + + def __getitem__(self, fileobj): + try: + fd = self._selector._fileobj_lookup(fileobj) + return self._selector._fd_to_key[fd] + except KeyError: + raise KeyError("{0!r} is not registered".format(fileobj)) + + def __iter__(self): + return iter(self._selector._fd_to_key) + + +class BaseSelector(six.with_metaclass(ABCMeta)): + """Selector abstract base class. + + A selector supports registering file objects to be monitored for specific + I/O events. + + A file object is a file descriptor or any object with a `fileno()` method. + An arbitrary object can be attached to the file object, which can be used + for example to store context information, a callback, etc. + + A selector can use various implementations (select(), poll(), epoll()...) + depending on the platform. The default `Selector` class uses the most + efficient implementation on the current platform. + """ + + @abstractmethod + def register(self, fileobj, events, data=None): + """Register a file object. + + Parameters: + fileobj -- file object or file descriptor + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + + Raises: + ValueError if events is invalid + KeyError if fileobj is already registered + OSError if fileobj is closed or otherwise is unacceptable to + the underlying system call (if a system call is made) + + Note: + OSError may or may not be raised + """ + raise NotImplementedError + + @abstractmethod + def unregister(self, fileobj): + """Unregister a file object. + + Parameters: + fileobj -- file object or file descriptor + + Returns: + SelectorKey instance + + Raises: + KeyError if fileobj is not registered + + Note: + If fileobj is registered but has since been closed this does + *not* raise OSError (even if the wrapped syscall does) + """ + raise NotImplementedError + + def modify(self, fileobj, events, data=None): + """Change a registered file object monitored events or attached data. + + Parameters: + fileobj -- file object or file descriptor + events -- events to monitor (bitwise mask of EVENT_READ|EVENT_WRITE) + data -- attached data + + Returns: + SelectorKey instance + + Raises: + Anything that unregister() or register() raises + """ + self.unregister(fileobj) + return self.register(fileobj, events, data) + + @abstractmethod + def select(self, timeout=None): + """Perform the actual selection, until some monitored file objects are + ready or a timeout expires. 
+ + Parameters: + timeout -- if timeout > 0, this specifies the maximum wait time, in + seconds + if timeout <= 0, the select() call won't block, and will + report the currently ready file objects + if timeout is None, select() will block until a monitored + file object becomes ready + + Returns: + list of (key, events) for ready file objects + `events` is a bitwise mask of EVENT_READ|EVENT_WRITE + """ + raise NotImplementedError + + def close(self): + """Close the selector. + + This must be called to make sure that any underlying resource is freed. + """ + pass + + def get_key(self, fileobj): + """Return the key associated to a registered file object. + + Returns: + SelectorKey for this file object + """ + mapping = self.get_map() + if mapping is None: + raise RuntimeError('Selector is closed') + try: + return mapping[fileobj] + except KeyError: + raise KeyError("{0!r} is not registered".format(fileobj)) + + @abstractmethod + def get_map(self): + """Return a mapping of file objects to selector keys.""" + raise NotImplementedError + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + +class _BaseSelectorImpl(BaseSelector): + """Base selector implementation.""" + + def __init__(self): + # this maps file descriptors to keys + self._fd_to_key = {} + # read-only mapping returned by get_map() + self._map = _SelectorMapping(self) + + def _fileobj_lookup(self, fileobj): + """Return a file descriptor from a file object. + + This wraps _fileobj_to_fd() to do an exhaustive search in case + the object is invalid but we still have it in our map. This + is used by unregister() so we can unregister an object that + was previously registered even if it is closed. It is also + used by _SelectorMapping. + """ + try: + return _fileobj_to_fd(fileobj) + except ValueError: + # Do an exhaustive search. + for key in self._fd_to_key.values(): + if key.fileobj is fileobj: + return key.fd + # Raise ValueError after all. + raise + + def register(self, fileobj, events, data=None): + if (not events) or (events & ~(EVENT_READ | EVENT_WRITE)): + raise ValueError("Invalid events: {0!r}".format(events)) + + key = SelectorKey(fileobj, self._fileobj_lookup(fileobj), events, data) + + if key.fd in self._fd_to_key: + raise KeyError("{0!r} (FD {1}) is already registered" + .format(fileobj, key.fd)) + + self._fd_to_key[key.fd] = key + return key + + def unregister(self, fileobj): + try: + key = self._fd_to_key.pop(self._fileobj_lookup(fileobj)) + except KeyError: + raise KeyError("{0!r} is not registered".format(fileobj)) + return key + + def modify(self, fileobj, events, data=None): + # TODO: Subclasses can probably optimize this even further. + try: + key = self._fd_to_key[self._fileobj_lookup(fileobj)] + except KeyError: + raise KeyError("{0!r} is not registered".format(fileobj)) + if events != key.events: + self.unregister(fileobj) + key = self.register(fileobj, events, data) + elif data != key.data: + # Use a shortcut to update the data. + key = key._replace(data=data) + self._fd_to_key[key.fd] = key + return key + + def close(self): + self._fd_to_key.clear() + self._map = None + + def get_map(self): + return self._map + + def _key_from_fd(self, fd): + """Return the key associated to a given file descriptor. 
+ + Parameters: + fd -- file descriptor + + Returns: + corresponding key, or None if not found + """ + try: + return self._fd_to_key[fd] + except KeyError: + return None + + +class SelectSelector(_BaseSelectorImpl): + """Select-based selector.""" + + def __init__(self): + super(SelectSelector, self).__init__() + self._readers = set() + self._writers = set() + + def register(self, fileobj, events, data=None): + key = super(SelectSelector, self).register(fileobj, events, data) + if events & EVENT_READ: + self._readers.add(key.fd) + if events & EVENT_WRITE: + self._writers.add(key.fd) + return key + + def unregister(self, fileobj): + key = super(SelectSelector, self).unregister(fileobj) + self._readers.discard(key.fd) + self._writers.discard(key.fd) + return key + + if sys.platform == 'win32': + def _select(self, r, w, _, timeout=None): + r, w, x = select.select(r, w, w, timeout) + return r, w + x, [] + else: + _select = select.select + + def select(self, timeout=None): + timeout = None if timeout is None else max(timeout, 0) + ready = [] + try: + r, w, _ = wrap_error(self._select, + self._readers, self._writers, [], timeout) + except InterruptedError: + return ready + r = set(r) + w = set(w) + for fd in r | w: + events = 0 + if fd in r: + events |= EVENT_READ + if fd in w: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + +if hasattr(select, 'poll'): + + class PollSelector(_BaseSelectorImpl): + """Poll-based selector.""" + + def __init__(self): + super(PollSelector, self).__init__() + self._poll = select.poll() + + def register(self, fileobj, events, data=None): + key = super(PollSelector, self).register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= select.POLLIN + if events & EVENT_WRITE: + poll_events |= select.POLLOUT + self._poll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super(PollSelector, self).unregister(fileobj) + self._poll.unregister(key.fd) + return key + + def select(self, timeout=None): + if timeout is None: + timeout = None + elif timeout <= 0: + timeout = 0 + else: + # poll() has a resolution of 1 millisecond, round away from + # zero to wait *at least* timeout seconds. + timeout = int(math.ceil(timeout * 1e3)) + ready = [] + try: + fd_event_list = wrap_error(self._poll.poll, timeout) + except InterruptedError: + return ready + for fd, event in fd_event_list: + events = 0 + if event & ~select.POLLIN: + events |= EVENT_WRITE + if event & ~select.POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + +if hasattr(select, 'epoll'): + + class EpollSelector(_BaseSelectorImpl): + """Epoll-based selector.""" + + def __init__(self): + super(EpollSelector, self).__init__() + self._epoll = select.epoll() + + def fileno(self): + return self._epoll.fileno() + + def register(self, fileobj, events, data=None): + key = super(EpollSelector, self).register(fileobj, events, data) + epoll_events = 0 + if events & EVENT_READ: + epoll_events |= select.EPOLLIN + if events & EVENT_WRITE: + epoll_events |= select.EPOLLOUT + self._epoll.register(key.fd, epoll_events) + return key + + def unregister(self, fileobj): + key = super(EpollSelector, self).unregister(fileobj) + try: + self._epoll.unregister(key.fd) + except IOError: + # This can happen if the FD was closed since it + # was registered. 
+ pass + return key + + def select(self, timeout=None): + if timeout is None: + timeout = -1 + elif timeout <= 0: + timeout = 0 + else: + # epoll_wait() has a resolution of 1 millisecond, round away + # from zero to wait *at least* timeout seconds. + timeout = math.ceil(timeout * 1e3) * 1e-3 + + # epoll_wait() expects `maxevents` to be greater than zero; + # we want to make sure that `select()` can be called when no + # FD is registered. + max_ev = max(len(self._fd_to_key), 1) + + ready = [] + try: + fd_event_list = wrap_error(self._epoll.poll, timeout, max_ev) + except InterruptedError: + return ready + for fd, event in fd_event_list: + events = 0 + if event & ~select.EPOLLIN: + events |= EVENT_WRITE + if event & ~select.EPOLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + self._epoll.close() + super(EpollSelector, self).close() + + +if hasattr(select, 'devpoll'): + + class DevpollSelector(_BaseSelectorImpl): + """Solaris /dev/poll selector.""" + + def __init__(self): + super(DevpollSelector, self).__init__() + self._devpoll = select.devpoll() + + def fileno(self): + return self._devpoll.fileno() + + def register(self, fileobj, events, data=None): + key = super(DevpollSelector, self).register(fileobj, events, data) + poll_events = 0 + if events & EVENT_READ: + poll_events |= select.POLLIN + if events & EVENT_WRITE: + poll_events |= select.POLLOUT + self._devpoll.register(key.fd, poll_events) + return key + + def unregister(self, fileobj): + key = super(DevpollSelector, self).unregister(fileobj) + self._devpoll.unregister(key.fd) + return key + + def select(self, timeout=None): + if timeout is None: + timeout = None + elif timeout <= 0: + timeout = 0 + else: + # devpoll() has a resolution of 1 millisecond, round away from + # zero to wait *at least* timeout seconds. + timeout = math.ceil(timeout * 1e3) + ready = [] + try: + fd_event_list = self._devpoll.poll(timeout) + except InterruptedError: + return ready + for fd, event in fd_event_list: + events = 0 + if event & ~select.POLLIN: + events |= EVENT_WRITE + if event & ~select.POLLOUT: + events |= EVENT_READ + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + self._devpoll.close() + super(DevpollSelector, self).close() + + +if hasattr(select, 'kqueue'): + + class KqueueSelector(_BaseSelectorImpl): + """Kqueue-based selector.""" + + def __init__(self): + super(KqueueSelector, self).__init__() + self._kqueue = select.kqueue() + + def fileno(self): + return self._kqueue.fileno() + + def register(self, fileobj, events, data=None): + key = super(KqueueSelector, self).register(fileobj, events, data) + if events & EVENT_READ: + kev = select.kevent(key.fd, select.KQ_FILTER_READ, + select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + if events & EVENT_WRITE: + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, + select.KQ_EV_ADD) + self._kqueue.control([kev], 0, 0) + return key + + def unregister(self, fileobj): + key = super(KqueueSelector, self).unregister(fileobj) + if key.events & EVENT_READ: + kev = select.kevent(key.fd, select.KQ_FILTER_READ, + select.KQ_EV_DELETE) + try: + self._kqueue.control([kev], 0, 0) + except OSError: + # This can happen if the FD was closed since it + # was registered. 
+ pass + if key.events & EVENT_WRITE: + kev = select.kevent(key.fd, select.KQ_FILTER_WRITE, + select.KQ_EV_DELETE) + try: + self._kqueue.control([kev], 0, 0) + except OSError: + # See comment above. + pass + return key + + def select(self, timeout=None): + timeout = None if timeout is None else max(timeout, 0) + max_ev = len(self._fd_to_key) + ready = [] + try: + kev_list = wrap_error(self._kqueue.control, + None, max_ev, timeout) + except InterruptedError: + return ready + for kev in kev_list: + fd = kev.ident + flag = kev.filter + events = 0 + if flag == select.KQ_FILTER_READ: + events |= EVENT_READ + if flag == select.KQ_FILTER_WRITE: + events |= EVENT_WRITE + + key = self._key_from_fd(fd) + if key: + ready.append((key, events & key.events)) + return ready + + def close(self): + self._kqueue.close() + super(KqueueSelector, self).close() + + +# Choose the best implementation, roughly: +# epoll|kqueue|devpoll > poll > select. +# select() also can't accept a FD > FD_SETSIZE (usually around 1024) +if 'KqueueSelector' in globals(): + DefaultSelector = KqueueSelector +elif 'EpollSelector' in globals(): + DefaultSelector = EpollSelector +elif 'DevpollSelector' in globals(): + DefaultSelector = DevpollSelector +elif 'PollSelector' in globals(): + DefaultSelector = PollSelector +else: + DefaultSelector = SelectSelector diff --git a/src/six.py b/src/six.py new file mode 100644 index 0000000..6ca60a2 --- /dev/null +++ b/src/six.py @@ -0,0 +1,868 @@ +# Copyright (c) 2010-2016 Benjamin Peterson +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Utilities for writing code that runs on Python 2 and 3""" + +from __future__ import absolute_import + +import functools +import itertools +import operator +import sys +import types + +__author__ = "Benjamin Peterson " +__version__ = "1.10.0" + + +# Useful for very coarse version differentiation. +PY2 = sys.version_info[0] == 2 +PY3 = sys.version_info[0] == 3 +PY34 = sys.version_info[0:2] >= (3, 4) + +if PY3: + string_types = str, + integer_types = int, + class_types = type, + text_type = str + binary_type = bytes + + MAXSIZE = sys.maxsize +else: + string_types = basestring, + integer_types = (int, long) + class_types = (type, types.ClassType) + text_type = unicode + binary_type = str + + if sys.platform.startswith("java"): + # Jython always uses 32 bits. + MAXSIZE = int((1 << 31) - 1) + else: + # It's possible to have sizeof(long) != sizeof(Py_ssize_t). 
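+        # len() raises OverflowError when __len__ returns a value that
+        # does not fit in a Py_ssize_t, so the probe class below
+        # distinguishes 32-bit from 64-bit builds.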
+ class X(object): + + def __len__(self): + return 1 << 31 + try: + len(X()) + except OverflowError: + # 32-bit + MAXSIZE = int((1 << 31) - 1) + else: + # 64-bit + MAXSIZE = int((1 << 63) - 1) + del X + + +def _add_doc(func, doc): + """Add documentation to a function.""" + func.__doc__ = doc + + +def _import_module(name): + """Import module, returning the module after the last dot.""" + __import__(name) + return sys.modules[name] + + +class _LazyDescr(object): + + def __init__(self, name): + self.name = name + + def __get__(self, obj, tp): + result = self._resolve() + setattr(obj, self.name, result) # Invokes __set__. + try: + # This is a bit ugly, but it avoids running this again by + # removing this descriptor. + delattr(obj.__class__, self.name) + except AttributeError: + pass + return result + + +class MovedModule(_LazyDescr): + + def __init__(self, name, old, new=None): + super(MovedModule, self).__init__(name) + if PY3: + if new is None: + new = name + self.mod = new + else: + self.mod = old + + def _resolve(self): + return _import_module(self.mod) + + def __getattr__(self, attr): + _module = self._resolve() + value = getattr(_module, attr) + setattr(self, attr, value) + return value + + +class _LazyModule(types.ModuleType): + + def __init__(self, name): + super(_LazyModule, self).__init__(name) + self.__doc__ = self.__class__.__doc__ + + def __dir__(self): + attrs = ["__doc__", "__name__"] + attrs += [attr.name for attr in self._moved_attributes] + return attrs + + # Subclasses should override this + _moved_attributes = [] + + +class MovedAttribute(_LazyDescr): + + def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None): + super(MovedAttribute, self).__init__(name) + if PY3: + if new_mod is None: + new_mod = name + self.mod = new_mod + if new_attr is None: + if old_attr is None: + new_attr = name + else: + new_attr = old_attr + self.attr = new_attr + else: + self.mod = old_mod + if old_attr is None: + old_attr = name + self.attr = old_attr + + def _resolve(self): + module = _import_module(self.mod) + return getattr(module, self.attr) + + +class _SixMetaPathImporter(object): + + """ + A meta path importer to import six.moves and its submodules. + + This class implements a PEP302 finder and loader. It should be compatible + with Python 2.5 and all existing versions of Python3 + """ + + def __init__(self, six_module_name): + self.name = six_module_name + self.known_modules = {} + + def _add_module(self, mod, *fullnames): + for fullname in fullnames: + self.known_modules[self.name + "." + fullname] = mod + + def _get_module(self, fullname): + return self.known_modules[self.name + "." + fullname] + + def find_module(self, fullname, path=None): + if fullname in self.known_modules: + return self + return None + + def __get_module(self, fullname): + try: + return self.known_modules[fullname] + except KeyError: + raise ImportError("This loader does not know module " + fullname) + + def load_module(self, fullname): + try: + # in case of a reload + return sys.modules[fullname] + except KeyError: + pass + mod = self.__get_module(fullname) + if isinstance(mod, MovedModule): + mod = mod._resolve() + else: + mod.__loader__ = self + sys.modules[fullname] = mod + return mod + + def is_package(self, fullname): + """ + Return true, if the named module is a package. 
+ + We need this method to get correct spec objects with + Python 3.4 (see PEP451) + """ + return hasattr(self.__get_module(fullname), "__path__") + + def get_code(self, fullname): + """Return None + + Required, if is_package is implemented""" + self.__get_module(fullname) # eventually raises ImportError + return None + get_source = get_code # same as get_code + +_importer = _SixMetaPathImporter(__name__) + + +class _MovedItems(_LazyModule): + + """Lazy loading of moved objects""" + __path__ = [] # mark as package + + +_moved_attributes = [ + MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"), + MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"), + MovedAttribute("filterfalse", "itertools", "itertools", "ifilterfalse", "filterfalse"), + MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"), + MovedAttribute("intern", "__builtin__", "sys"), + MovedAttribute("map", "itertools", "builtins", "imap", "map"), + MovedAttribute("getcwd", "os", "os", "getcwdu", "getcwd"), + MovedAttribute("getcwdb", "os", "os", "getcwd", "getcwdb"), + MovedAttribute("range", "__builtin__", "builtins", "xrange", "range"), + MovedAttribute("reload_module", "__builtin__", "importlib" if PY34 else "imp", "reload"), + MovedAttribute("reduce", "__builtin__", "functools"), + MovedAttribute("shlex_quote", "pipes", "shlex", "quote"), + MovedAttribute("StringIO", "StringIO", "io"), + MovedAttribute("UserDict", "UserDict", "collections"), + MovedAttribute("UserList", "UserList", "collections"), + MovedAttribute("UserString", "UserString", "collections"), + MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"), + MovedAttribute("zip", "itertools", "builtins", "izip", "zip"), + MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"), + MovedModule("builtins", "__builtin__"), + MovedModule("configparser", "ConfigParser"), + MovedModule("copyreg", "copy_reg"), + MovedModule("dbm_gnu", "gdbm", "dbm.gnu"), + MovedModule("_dummy_thread", "dummy_thread", "_dummy_thread"), + MovedModule("http_cookiejar", "cookielib", "http.cookiejar"), + MovedModule("http_cookies", "Cookie", "http.cookies"), + MovedModule("html_entities", "htmlentitydefs", "html.entities"), + MovedModule("html_parser", "HTMLParser", "html.parser"), + MovedModule("http_client", "httplib", "http.client"), + MovedModule("email_mime_multipart", "email.MIMEMultipart", "email.mime.multipart"), + MovedModule("email_mime_nonmultipart", "email.MIMENonMultipart", "email.mime.nonmultipart"), + MovedModule("email_mime_text", "email.MIMEText", "email.mime.text"), + MovedModule("email_mime_base", "email.MIMEBase", "email.mime.base"), + MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"), + MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"), + MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"), + MovedModule("cPickle", "cPickle", "pickle"), + MovedModule("queue", "Queue"), + MovedModule("reprlib", "repr"), + MovedModule("socketserver", "SocketServer"), + MovedModule("_thread", "thread", "_thread"), + MovedModule("tkinter", "Tkinter"), + MovedModule("tkinter_dialog", "Dialog", "tkinter.dialog"), + MovedModule("tkinter_filedialog", "FileDialog", "tkinter.filedialog"), + MovedModule("tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"), + MovedModule("tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"), + MovedModule("tkinter_tix", "Tix", "tkinter.tix"), + MovedModule("tkinter_ttk", "ttk", "tkinter.ttk"), + 
MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"), + MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"), + MovedModule("tkinter_colorchooser", "tkColorChooser", + "tkinter.colorchooser"), + MovedModule("tkinter_commondialog", "tkCommonDialog", + "tkinter.commondialog"), + MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"), + MovedModule("tkinter_font", "tkFont", "tkinter.font"), + MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"), + MovedModule("tkinter_tksimpledialog", "tkSimpleDialog", + "tkinter.simpledialog"), + MovedModule("urllib_parse", __name__ + ".moves.urllib_parse", "urllib.parse"), + MovedModule("urllib_error", __name__ + ".moves.urllib_error", "urllib.error"), + MovedModule("urllib", __name__ + ".moves.urllib", __name__ + ".moves.urllib"), + MovedModule("urllib_robotparser", "robotparser", "urllib.robotparser"), + MovedModule("xmlrpc_client", "xmlrpclib", "xmlrpc.client"), + MovedModule("xmlrpc_server", "SimpleXMLRPCServer", "xmlrpc.server"), +] +# Add windows specific modules. +if sys.platform == "win32": + _moved_attributes += [ + MovedModule("winreg", "_winreg"), + ] + +for attr in _moved_attributes: + setattr(_MovedItems, attr.name, attr) + if isinstance(attr, MovedModule): + _importer._add_module(attr, "moves." + attr.name) +del attr + +_MovedItems._moved_attributes = _moved_attributes + +moves = _MovedItems(__name__ + ".moves") +_importer._add_module(moves, "moves") + + +class Module_six_moves_urllib_parse(_LazyModule): + + """Lazy loading of moved objects in six.moves.urllib_parse""" + + +_urllib_parse_moved_attributes = [ + MovedAttribute("ParseResult", "urlparse", "urllib.parse"), + MovedAttribute("SplitResult", "urlparse", "urllib.parse"), + MovedAttribute("parse_qs", "urlparse", "urllib.parse"), + MovedAttribute("parse_qsl", "urlparse", "urllib.parse"), + MovedAttribute("urldefrag", "urlparse", "urllib.parse"), + MovedAttribute("urljoin", "urlparse", "urllib.parse"), + MovedAttribute("urlparse", "urlparse", "urllib.parse"), + MovedAttribute("urlsplit", "urlparse", "urllib.parse"), + MovedAttribute("urlunparse", "urlparse", "urllib.parse"), + MovedAttribute("urlunsplit", "urlparse", "urllib.parse"), + MovedAttribute("quote", "urllib", "urllib.parse"), + MovedAttribute("quote_plus", "urllib", "urllib.parse"), + MovedAttribute("unquote", "urllib", "urllib.parse"), + MovedAttribute("unquote_plus", "urllib", "urllib.parse"), + MovedAttribute("urlencode", "urllib", "urllib.parse"), + MovedAttribute("splitquery", "urllib", "urllib.parse"), + MovedAttribute("splittag", "urllib", "urllib.parse"), + MovedAttribute("splituser", "urllib", "urllib.parse"), + MovedAttribute("uses_fragment", "urlparse", "urllib.parse"), + MovedAttribute("uses_netloc", "urlparse", "urllib.parse"), + MovedAttribute("uses_params", "urlparse", "urllib.parse"), + MovedAttribute("uses_query", "urlparse", "urllib.parse"), + MovedAttribute("uses_relative", "urlparse", "urllib.parse"), +] +for attr in _urllib_parse_moved_attributes: + setattr(Module_six_moves_urllib_parse, attr.name, attr) +del attr + +Module_six_moves_urllib_parse._moved_attributes = _urllib_parse_moved_attributes + +_importer._add_module(Module_six_moves_urllib_parse(__name__ + ".moves.urllib_parse"), + "moves.urllib_parse", "moves.urllib.parse") + + +class Module_six_moves_urllib_error(_LazyModule): + + """Lazy loading of moved objects in six.moves.urllib_error""" + + +_urllib_error_moved_attributes = [ + MovedAttribute("URLError", "urllib2", "urllib.error"), + 
MovedAttribute("HTTPError", "urllib2", "urllib.error"), + MovedAttribute("ContentTooShortError", "urllib", "urllib.error"), +] +for attr in _urllib_error_moved_attributes: + setattr(Module_six_moves_urllib_error, attr.name, attr) +del attr + +Module_six_moves_urllib_error._moved_attributes = _urllib_error_moved_attributes + +_importer._add_module(Module_six_moves_urllib_error(__name__ + ".moves.urllib.error"), + "moves.urllib_error", "moves.urllib.error") + + +class Module_six_moves_urllib_request(_LazyModule): + + """Lazy loading of moved objects in six.moves.urllib_request""" + + +_urllib_request_moved_attributes = [ + MovedAttribute("urlopen", "urllib2", "urllib.request"), + MovedAttribute("install_opener", "urllib2", "urllib.request"), + MovedAttribute("build_opener", "urllib2", "urllib.request"), + MovedAttribute("pathname2url", "urllib", "urllib.request"), + MovedAttribute("url2pathname", "urllib", "urllib.request"), + MovedAttribute("getproxies", "urllib", "urllib.request"), + MovedAttribute("Request", "urllib2", "urllib.request"), + MovedAttribute("OpenerDirector", "urllib2", "urllib.request"), + MovedAttribute("HTTPDefaultErrorHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPRedirectHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPCookieProcessor", "urllib2", "urllib.request"), + MovedAttribute("ProxyHandler", "urllib2", "urllib.request"), + MovedAttribute("BaseHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPPasswordMgr", "urllib2", "urllib.request"), + MovedAttribute("HTTPPasswordMgrWithDefaultRealm", "urllib2", "urllib.request"), + MovedAttribute("AbstractBasicAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPBasicAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("ProxyBasicAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("AbstractDigestAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPDigestAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("ProxyDigestAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPSHandler", "urllib2", "urllib.request"), + MovedAttribute("FileHandler", "urllib2", "urllib.request"), + MovedAttribute("FTPHandler", "urllib2", "urllib.request"), + MovedAttribute("CacheFTPHandler", "urllib2", "urllib.request"), + MovedAttribute("UnknownHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPErrorProcessor", "urllib2", "urllib.request"), + MovedAttribute("urlretrieve", "urllib", "urllib.request"), + MovedAttribute("urlcleanup", "urllib", "urllib.request"), + MovedAttribute("URLopener", "urllib", "urllib.request"), + MovedAttribute("FancyURLopener", "urllib", "urllib.request"), + MovedAttribute("proxy_bypass", "urllib", "urllib.request"), +] +for attr in _urllib_request_moved_attributes: + setattr(Module_six_moves_urllib_request, attr.name, attr) +del attr + +Module_six_moves_urllib_request._moved_attributes = _urllib_request_moved_attributes + +_importer._add_module(Module_six_moves_urllib_request(__name__ + ".moves.urllib.request"), + "moves.urllib_request", "moves.urllib.request") + + +class Module_six_moves_urllib_response(_LazyModule): + + """Lazy loading of moved objects in six.moves.urllib_response""" + + +_urllib_response_moved_attributes = [ + MovedAttribute("addbase", "urllib", "urllib.response"), + MovedAttribute("addclosehook", "urllib", "urllib.response"), + MovedAttribute("addinfo", "urllib", "urllib.response"), + MovedAttribute("addinfourl", "urllib", 
"urllib.response"), +] +for attr in _urllib_response_moved_attributes: + setattr(Module_six_moves_urllib_response, attr.name, attr) +del attr + +Module_six_moves_urllib_response._moved_attributes = _urllib_response_moved_attributes + +_importer._add_module(Module_six_moves_urllib_response(__name__ + ".moves.urllib.response"), + "moves.urllib_response", "moves.urllib.response") + + +class Module_six_moves_urllib_robotparser(_LazyModule): + + """Lazy loading of moved objects in six.moves.urllib_robotparser""" + + +_urllib_robotparser_moved_attributes = [ + MovedAttribute("RobotFileParser", "robotparser", "urllib.robotparser"), +] +for attr in _urllib_robotparser_moved_attributes: + setattr(Module_six_moves_urllib_robotparser, attr.name, attr) +del attr + +Module_six_moves_urllib_robotparser._moved_attributes = _urllib_robotparser_moved_attributes + +_importer._add_module(Module_six_moves_urllib_robotparser(__name__ + ".moves.urllib.robotparser"), + "moves.urllib_robotparser", "moves.urllib.robotparser") + + +class Module_six_moves_urllib(types.ModuleType): + + """Create a six.moves.urllib namespace that resembles the Python 3 namespace""" + __path__ = [] # mark as package + parse = _importer._get_module("moves.urllib_parse") + error = _importer._get_module("moves.urllib_error") + request = _importer._get_module("moves.urllib_request") + response = _importer._get_module("moves.urllib_response") + robotparser = _importer._get_module("moves.urllib_robotparser") + + def __dir__(self): + return ['parse', 'error', 'request', 'response', 'robotparser'] + +_importer._add_module(Module_six_moves_urllib(__name__ + ".moves.urllib"), + "moves.urllib") + + +def add_move(move): + """Add an item to six.moves.""" + setattr(_MovedItems, move.name, move) + + +def remove_move(name): + """Remove item from six.moves.""" + try: + delattr(_MovedItems, name) + except AttributeError: + try: + del moves.__dict__[name] + except KeyError: + raise AttributeError("no such move, %r" % (name,)) + + +if PY3: + _meth_func = "__func__" + _meth_self = "__self__" + + _func_closure = "__closure__" + _func_code = "__code__" + _func_defaults = "__defaults__" + _func_globals = "__globals__" +else: + _meth_func = "im_func" + _meth_self = "im_self" + + _func_closure = "func_closure" + _func_code = "func_code" + _func_defaults = "func_defaults" + _func_globals = "func_globals" + + +try: + advance_iterator = next +except NameError: + def advance_iterator(it): + return it.next() +next = advance_iterator + + +try: + callable = callable +except NameError: + def callable(obj): + return any("__call__" in klass.__dict__ for klass in type(obj).__mro__) + + +if PY3: + def get_unbound_function(unbound): + return unbound + + create_bound_method = types.MethodType + + def create_unbound_method(func, cls): + return func + + Iterator = object +else: + def get_unbound_function(unbound): + return unbound.im_func + + def create_bound_method(func, obj): + return types.MethodType(func, obj, obj.__class__) + + def create_unbound_method(func, cls): + return types.MethodType(func, None, cls) + + class Iterator(object): + + def next(self): + return type(self).__next__(self) + + callable = callable +_add_doc(get_unbound_function, + """Get the function out of a possibly unbound function""") + + +get_method_function = operator.attrgetter(_meth_func) +get_method_self = operator.attrgetter(_meth_self) +get_function_closure = operator.attrgetter(_func_closure) +get_function_code = operator.attrgetter(_func_code) +get_function_defaults = 
operator.attrgetter(_func_defaults) +get_function_globals = operator.attrgetter(_func_globals) + + +if PY3: + def iterkeys(d, **kw): + return iter(d.keys(**kw)) + + def itervalues(d, **kw): + return iter(d.values(**kw)) + + def iteritems(d, **kw): + return iter(d.items(**kw)) + + def iterlists(d, **kw): + return iter(d.lists(**kw)) + + viewkeys = operator.methodcaller("keys") + + viewvalues = operator.methodcaller("values") + + viewitems = operator.methodcaller("items") +else: + def iterkeys(d, **kw): + return d.iterkeys(**kw) + + def itervalues(d, **kw): + return d.itervalues(**kw) + + def iteritems(d, **kw): + return d.iteritems(**kw) + + def iterlists(d, **kw): + return d.iterlists(**kw) + + viewkeys = operator.methodcaller("viewkeys") + + viewvalues = operator.methodcaller("viewvalues") + + viewitems = operator.methodcaller("viewitems") + +_add_doc(iterkeys, "Return an iterator over the keys of a dictionary.") +_add_doc(itervalues, "Return an iterator over the values of a dictionary.") +_add_doc(iteritems, + "Return an iterator over the (key, value) pairs of a dictionary.") +_add_doc(iterlists, + "Return an iterator over the (key, [values]) pairs of a dictionary.") + + +if PY3: + def b(s): + return s.encode("latin-1") + + def u(s): + return s + unichr = chr + import struct + int2byte = struct.Struct(">B").pack + del struct + byte2int = operator.itemgetter(0) + indexbytes = operator.getitem + iterbytes = iter + import io + StringIO = io.StringIO + BytesIO = io.BytesIO + _assertCountEqual = "assertCountEqual" + if sys.version_info[1] <= 1: + _assertRaisesRegex = "assertRaisesRegexp" + _assertRegex = "assertRegexpMatches" + else: + _assertRaisesRegex = "assertRaisesRegex" + _assertRegex = "assertRegex" +else: + def b(s): + return s + # Workaround for standalone backslash + + def u(s): + return unicode(s.replace(r'\\', r'\\\\'), "unicode_escape") + unichr = unichr + int2byte = chr + + def byte2int(bs): + return ord(bs[0]) + + def indexbytes(buf, i): + return ord(buf[i]) + iterbytes = functools.partial(itertools.imap, ord) + import StringIO + StringIO = BytesIO = StringIO.StringIO + _assertCountEqual = "assertItemsEqual" + _assertRaisesRegex = "assertRaisesRegexp" + _assertRegex = "assertRegexpMatches" +_add_doc(b, """Byte literal""") +_add_doc(u, """Text literal""") + + +def assertCountEqual(self, *args, **kwargs): + return getattr(self, _assertCountEqual)(*args, **kwargs) + + +def assertRaisesRegex(self, *args, **kwargs): + return getattr(self, _assertRaisesRegex)(*args, **kwargs) + + +def assertRegex(self, *args, **kwargs): + return getattr(self, _assertRegex)(*args, **kwargs) + + +if PY3: + exec_ = getattr(moves.builtins, "exec") + + def reraise(tp, value, tb=None): + if value is None: + value = tp() + if value.__traceback__ is not tb: + raise value.with_traceback(tb) + raise value + +else: + def exec_(_code_, _globs_=None, _locs_=None): + """Execute code in a namespace.""" + if _globs_ is None: + frame = sys._getframe(1) + _globs_ = frame.f_globals + if _locs_ is None: + _locs_ = frame.f_locals + del frame + elif _locs_ is None: + _locs_ = _globs_ + exec("""exec _code_ in _globs_, _locs_""") + + exec_("""def reraise(tp, value, tb=None): + raise tp, value, tb +""") + + +if sys.version_info[:2] == (3, 2): + exec_("""def raise_from(value, from_value): + if from_value is None: + raise value + raise value from from_value +""") +elif sys.version_info[:2] > (3, 2): + exec_("""def raise_from(value, from_value): + raise value from from_value +""") +else: + def raise_from(value, from_value): + 
raise value + + +print_ = getattr(moves.builtins, "print", None) +if print_ is None: + def print_(*args, **kwargs): + """The new-style print function for Python 2.4 and 2.5.""" + fp = kwargs.pop("file", sys.stdout) + if fp is None: + return + + def write(data): + if not isinstance(data, basestring): + data = str(data) + # If the file has an encoding, encode unicode with it. + if (isinstance(fp, file) and + isinstance(data, unicode) and + fp.encoding is not None): + errors = getattr(fp, "errors", None) + if errors is None: + errors = "strict" + data = data.encode(fp.encoding, errors) + fp.write(data) + want_unicode = False + sep = kwargs.pop("sep", None) + if sep is not None: + if isinstance(sep, unicode): + want_unicode = True + elif not isinstance(sep, str): + raise TypeError("sep must be None or a string") + end = kwargs.pop("end", None) + if end is not None: + if isinstance(end, unicode): + want_unicode = True + elif not isinstance(end, str): + raise TypeError("end must be None or a string") + if kwargs: + raise TypeError("invalid keyword arguments to print()") + if not want_unicode: + for arg in args: + if isinstance(arg, unicode): + want_unicode = True + break + if want_unicode: + newline = unicode("\n") + space = unicode(" ") + else: + newline = "\n" + space = " " + if sep is None: + sep = space + if end is None: + end = newline + for i, arg in enumerate(args): + if i: + write(sep) + write(arg) + write(end) +if sys.version_info[:2] < (3, 3): + _print = print_ + + def print_(*args, **kwargs): + fp = kwargs.get("file", sys.stdout) + flush = kwargs.pop("flush", False) + _print(*args, **kwargs) + if flush and fp is not None: + fp.flush() + +_add_doc(reraise, """Reraise an exception.""") + +if sys.version_info[0:2] < (3, 4): + def wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS, + updated=functools.WRAPPER_UPDATES): + def wrapper(f): + f = functools.wraps(wrapped, assigned, updated)(f) + f.__wrapped__ = wrapped + return f + return wrapper +else: + wraps = functools.wraps + + +def with_metaclass(meta, *bases): + """Create a base class with a metaclass.""" + # This requires a bit of explanation: the basic idea is to make a dummy + # metaclass for one level of class instantiation that replaces itself with + # the actual metaclass. + class metaclass(meta): + + def __new__(cls, name, this_bases, d): + return meta(name, bases, d) + return type.__new__(metaclass, 'temporary_class', (), {}) + + +def add_metaclass(metaclass): + """Class decorator for creating a class with a metaclass.""" + def wrapper(cls): + orig_vars = cls.__dict__.copy() + slots = orig_vars.get('__slots__') + if slots is not None: + if isinstance(slots, str): + slots = [slots] + for slots_var in slots: + orig_vars.pop(slots_var) + orig_vars.pop('__dict__', None) + orig_vars.pop('__weakref__', None) + return metaclass(cls.__name__, cls.__bases__, orig_vars) + return wrapper + + +def python_2_unicode_compatible(klass): + """ + A decorator that defines __unicode__ and __str__ methods under Python 2. + Under Python 3 it does nothing. + + To support Python 2 and 3 with a single code base, define a __str__ method + returning text and apply this decorator to the class. + """ + if PY2: + if '__str__' not in klass.__dict__: + raise ValueError("@python_2_unicode_compatible cannot be applied " + "to %s because it doesn't define __str__()." % + klass.__name__) + klass.__unicode__ = klass.__str__ + klass.__str__ = lambda self: self.__unicode__().encode('utf-8') + return klass + + +# Complete the moves implementation. 
+# This code is at the end of this module to speed up module loading.
+# Turn this module into a package.
+__path__ = []  # required for PEP 302 and PEP 451
+__package__ = __name__  # see PEP 366 @ReservedAssignment
+if globals().get("__spec__") is not None:
+    __spec__.submodule_search_locations = []  # PEP 451 @UndefinedVariable
+# Remove other six meta path importers, since they cause problems. This can
+# happen if six is removed from sys.modules and then reloaded. (Setuptools does
+# this for some reason.)
+if sys.meta_path:
+    for i, importer in enumerate(sys.meta_path):
+        # Here's some real nastiness: Another "instance" of the six module might
+        # be floating around. Therefore, we can't use isinstance() to check for
+        # the six meta path importer, since the other six instance will have
+        # inserted an importer with different class.
+        if (type(importer).__name__ == "_SixMetaPathImporter" and
+                importer.name == __name__):
+            del sys.meta_path[i]
+            break
+    del i, importer
+# Finally, add the importer to the meta path import hook.
+sys.meta_path.append(_importer)
diff --git a/src/torsocks.py b/src/torsocks.py
index f51e9d1..cf462f8 100644
--- a/src/torsocks.py
+++ b/src/torsocks.py
@@ -22,6 +22,9 @@
 import sys
 import struct
 import socket
+import select
+import errno
+import os
 
 import error
 import log
@@ -30,76 +33,96 @@
 proxy_addr = None
 proxy_port = None
 
-queue       = None
-circ_id     = None
+queue = None
+circ_id = None
 orig_socket = socket.socket
 
-# Server-side SOCKSv5 errors.
+_ERRNO_RETRY = frozenset((errno.EAGAIN, errno.EWOULDBLOCK,
+                          errno.EINPROGRESS, errno.EINTR))
+
+_LOCAL_SOCKETS = frozenset(
+    getattr(socket, af) for af in [
+        'AF_UNIX', 'AF_LOCAL',
+        'AF_ROUTE', 'AF_KEY', 'AF_ALG', 'AF_NETLINK'
+    ]
+    if hasattr(socket, af)
+)
+
+# Map server-side SOCKSv5 errors to errno codes (as best we can; codes
+# 1 and 7 don't correspond to documented error codes for connect(2)).
 socks5_errors = {
-    0x00: "Request granted",
-    0x01: "General failure",
-    0x02: "Connection not allowed by ruleset",
-    0x03: "Network unreachable",
-    0x04: "Host unreachable",
-    0x05: "Connection refused by destination host",
-    0x06: "TTL expired",
-    0x07: "Command not supported / protocol error",
-    0x08: "Address type not supported",
+    0x00: 0,                   # Success
+    0x01: errno.EIO,           # General failure
+    0x02: errno.EACCES,        # Connection not allowed by ruleset
+    0x03: errno.ENETUNREACH,   # Network unreachable
+    0x04: errno.EHOSTUNREACH,  # Host unreachable
+    0x05: errno.ECONNREFUSED,  # Connection refused by destination host
+    0x06: errno.ETIMEDOUT,     # TTL expired
+    0x07: errno.ENOTSUP,       # Command not supported / protocol error
+    0x08: errno.EAFNOSUPPORT,  # Address type not supported
 }
 
 
 def send_queue(sock_name):
     """
     Inform caller about our newly created socket.
-
-    We need to temporarily use the original socket.socket implementation for
-    the queue to work.
     """
-    global queue, circ_id, orig_socket
+    global queue, circ_id
 
     assert (queue is not None) and (circ_id is not None)
 
-    tmp_socket = socket.socket
-    socket.socket = orig_socket
-
     queue.put([circ_id, sock_name])
 
-    socket.socket = tmp_socket
-
-
-def set_default_proxy(ip_addr, port):
-    """
-    Set the SOCKS proxy address and its port.
-    """
-
-    global proxy_addr, proxy_port
-    proxy_addr, proxy_port = ip_addr, port
-
 
-class torsocket(socket.socket):
+class _Torsocket(orig_socket):
     """
     Provides a minimal, Tor-specific SOCKSv5 interface.
     """
+    # Implementation note: socket.socket is (at least in Python 2) a
+    # wrapper object around _socket.socket.
Most superclass methods + # cannot be invoked via the usual super().method(self, args...) + # construct. One must use self._sock.method(args...) instead. + def __init__(self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, _sock=None): - self.sockfamily = family - self.socktype = type - - super(torsocket, self).__init__(family, type, proto, _sock) + self._sockfamily = family + self._socktype = type + self._connecting = False + self._connected = False + self._peer_addr = None + self._conn_err = None + + super(_Torsocket, self).__init__(family, type, proto, _sock) + + # FIXME: Arguably this should happen only on connect() so that + # attempts to connect to 127.0.0.1 can bypass the proxy server. + # However, that would make nonblocking mode significantly more + # complicated. We'd need an actual state machine instead of + # just a pair of booleans, and callers would need to be + # prepared to 'turn the crank' on the state machine. + self._authenticate() def _recv_all(self, num_bytes): """ - Try to read the given number of bytes. + Try to read the given number of bytes, blocking indefinitely + if necessary (even if the socket is in nonblocking mode). If we are unable to read all of it, an EOFError is raised. """ data = "" while len(data) < num_bytes: - more = self.recv(num_bytes - len(data)) + try: + more = self._sock.recv(num_bytes - len(data)) + except socket.error as e: + if e.errno not in _ERRNO_RETRY: + raise + + select.select([self], [], []) + continue + if not more: raise EOFError("Could read only %d of expected %d bytes." % (len(data), num_bytes)) @@ -107,6 +129,28 @@ def _recv_all(self, num_bytes): return data + def _send_all(self, msg): + """ + Try to send all of 'msg', blocking indefinitely if necessary + (even if the socket is in nonblocking mode). + """ + + sent = 0 + while sent < len(msg): + try: + n = self._sock.send(msg[sent:]) + except socket.error as e: + if e.errno not in _ERRNO_RETRY: + raise + + select.select([], [self], []) + continue + + if not n: + raise EOFError("Could send only %d of expected %d bytes." % + (sent, len(msg))) + sent += n + def _authenticate(self): """ Authenticate to our SOCKSv5 server. @@ -117,14 +161,8 @@ def _authenticate(self): # Connect to SOCKSv5 server. We use version 5 and one authentication # method, which is "no authentication". - try: - orig_socket.connect(self, (proxy_addr, proxy_port)) - except Exception as err: - logger.warning("connect() failed: %s" % err) - sys.exit(1) - - self.sendall("\x05\x01\x00") - + self._sock.connect((proxy_addr, proxy_port)) + self._send_all("\x05\x01\x00") resp = self._recv_all(2) if resp != "\x05\x00": raise error.SOCKSv5Error("Invalid server response: 0x%s" % @@ -145,8 +183,7 @@ def resolve(self, domain): # Tor defines a new command value, \x0f, that is used for domain # resolution. - self._authenticate() - self.sendall("\x05\xf0\x00\x03%s%s%s" % + self._send_all("\x05\xf0\x00\x03%s%s%s" % (chr(domain_len), domain, "\x00\x00")) resp = self._recv_all(10) @@ -157,38 +194,218 @@ def resolve(self, domain): return socket.inet_ntoa(resp[4:8]) def connect(self, addr_tuple): + err = self.connect_ex(addr_tuple) + if err: + raise socket.error(err, os.strerror(err)) + + def connect_ex(self, addr_tuple): """ Tell SOCKS server to connect to our destination. 
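+
+        Mirrors connect_ex() semantics: this returns 0 on success and
+        an errno value on failure; on a nonblocking socket it returns
+        errno.EINPROGRESS while the SOCKS handshake is still pending.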
""" dst_addr, dst_port = addr_tuple[0], int(addr_tuple[1]) + self._connecting = True + self._peer_addr = (dst_addr, dst_port) - self._authenticate() + logger.debug("Requesting connection to %s:%d.", dst_addr, dst_port) - # Tell SOCKS server to connect to destination. - - self.sendall("\x05\x01\x00\x01%s%s" % + self._send_all("\x05\x01\x00\x01%s%s" % (socket.inet_aton(dst_addr), struct.pack(">H", dst_port))) - resp = self._recv_all(4) - if resp[1] != "\x00": - val = int(resp[1].encode("hex"), 16) - if 0 <= val < len(socks5_errors): - raise error.SOCKSv5Error("SOCKSv5 connection failed because: " - "%s" % socks5_errors[val]) - else: - raise error.SOCKSv5Error("Unexpected SOCKSv5 error: %d" % val) + return self._attempt_finish_socks_handshake() - # Depending on address type, get address. + def _attempt_finish_socks_handshake(self): + # Receive the first byte of the server reply using the + # underlying recv() primitive, and suspend this operation if + # it comes back with EAGAIN, or fail it if it gives an error. + # Callers of connect_ex expect to get EINPROGRESS, not EAGAIN. + logger.debug("Attempting to read SOCKS reply.") + try: + resp0 = self._sock.recv(1) + except socket.error as e: + if e.errno in _ERRNO_RETRY: + logger.debug("SOCKS reply not yet available.") + return errno.EINPROGRESS + + logger.debug("Connection failure: %s", e) + self._connecting = False + self._conn_err = e.errno + return e.errno + + if resp0 != "\x05": + self._connecting = False + raise error.SOCKSv5Error( + "Protocol error: server reply begins with 0x%02x, not 0x05" + % ord(resp0)) + + # We are now committed to receiving and processing the server + # response. + resp = self._recv_all(3) + if resp[0] != "\x00": + self._connecting = False + val = ord(resp[0]) + if val in socks5_errors: + self._conn_err = socks5_errors[val] + logger.debug("Connection failure at protocol level: %s", + os.strerror(self._conn_err)) + return self._conn_err + else: + raise error.SOCKSv5Error("Unrecognized SOCKSv5 error: %d" % val) - if resp[3] == "\x01": + # Read and discard the rest of the reply, which consists of an + # address type (1 byte), variable-length address (depending on the + # address type), and port number (2 bytes). + if resp[2] == "\x01": self._recv_all(4) - elif resp[3] == "\x03": + elif resp[2] == "\x03": length = self._recv_all(1) - self._recv_all(length) + self._recv_all(ord(length)) else: self._recv_all(16) + self._recv_all(2) - # Get port. + # We are now officially connected. + logger.debug("Now connected to %s:%d.", *self._peer_addr) + self._connected = True + return 0 + + def _maybe_finish_socks_handshake(self): + if self._connected: + return + if not self._connecting: + raise socket.error(errno.ENOTCONN, os.strerror(errno.ENOTCONN)) + + err = self._attempt_finish_socks_handshake() + if err: + # Callers of _this_ function expect EAGAIN, not EINPROGRESS. + if err in _ERRNO_RETRY: + raise socket.error(errno.EAGAIN, os.strerror(errno.EAGAIN)) + raise socket.error(err, os.strerror(err)) + + # All of these functions must be prepared to process the final + # message of the SOCKS handshake. 
+ def send(self, *args): + self._maybe_finish_socks_handshake() + return self._sock.send(*args) + def sendall(self, *args): + self._maybe_finish_socks_handshake() + return self._sock.sendall(*args) + + def recv(self, *args): + self._maybe_finish_socks_handshake() + return self._sock.recv(*args) + def recv_into(self, *args): + self._maybe_finish_socks_handshake() + return self._sock.recv_into(*args) + + def makefile(self, *args): + # This one is a normal method on socket.socket. + self._maybe_finish_socks_handshake() + return super(_Torsocket, self).makefile(*args) + + # These sockets can only be used as client sockets. + def accept(self): raise NotImplementedError + def bind(self): raise NotImplementedError + def listen(self): raise NotImplementedError + + # These sockets can only be used as connected sockets. + def sendto(self, *a): raise NotImplementedError + def recvfrom(self, *a): raise NotImplementedError + def recvfrom_into(self, *a): raise NotImplementedError + + # Provide information about the ultimate destination, not the + # proxy server. On normal sockets, getpeername() works immediately + # after connect(), even if it returned EINPROGRESS. + def getpeername(self): + if not self._connecting: + raise socket.error(errno.ENOTCONN, os.strerror(errno.ENOTCONN)) + return self._peer_addr + + # Provide the pending connection error if appropriate. + def getsockopt(self, level, opt, *args): + if level == socket.SOL_SOCKET and opt == socket.SO_ERROR: + if self._connecting: + err = self._attempt_finish_socks_handshake() + if err == errno.EINPROGRESS: + return 0 # there's no pending connection error yet + + if self._conn_err is not None: + err = self._conn_err + self._conn_err = None + return err + + return self._sock.getsockopt(level, opt, *args) + + +def torsocket(family=socket.AF_INET, type=socket.SOCK_STREAM, + proto=0, _sock=None): + """ + Factory function usable as a monkey-patch for socket.socket. + """ - self._recv_all(2) + # Pass through local sockets. + if family in _LOCAL_SOCKETS: + return orig_socket(family, type, proto, _sock) + + # Tor only supports AF_INET sockets. + if family != socket.AF_INET: + raise socket.error(errno.EAFNOSUPPORT, os.strerror(errno.EAFNOSUPPORT)) + + # Tor only supports SOCK_STREAM sockets. + if type != socket.SOCK_STREAM: + raise socket.error(errno.ESOCKTNOSUPPORT, + os.strerror(errno.ESOCKTNOSUPPORT)) + + # Acceptable values for PROTO are 0 and IPPROTO_TCP. + if proto not in (0, socket.IPPROTO_TCP): + raise socket.error(errno.EPROTONOSUPPORT, + os.strerror(errno.EPROTONOSUPPORT)) + + return _Torsocket(family, type, proto, _sock) + +class MonkeyPatchedSocket(object): + """ + Context manager which monkey-patches socket.socket with + the above torsocket(). It also sets up this module's + global state. + """ + def __init__(self, queue, circ_id, socks_port, socks_addr="127.0.0.1"): + self._queue = queue + self._circ_id = circ_id + self._socks_addr = socks_addr + self._socks_port = socks_port + + self._orig_queue = None + self._orig_circ_id = None + self._orig_proxy_addr = None + self._orig_proxy_port = None + self._orig_socket = None + + def __enter__(self): + global queue, circ_id, proxy_addr, proxy_port, socket, torsocket + + # Make sure __exit__ can put everything back just as it was. 
+ self._orig_queue = queue + self._orig_circ_id = circ_id + self._orig_proxy_addr = proxy_addr + self._orig_proxy_port = proxy_port + self._orig_socket = socket.socket + + queue = self._queue + circ_id = self._circ_id + proxy_addr = self._socks_addr + proxy_port = self._socks_port + socket.socket = torsocket + + return self + + def __exit__(self, *dontcare): + global queue, circ_id, proxy_addr, proxy_port, socket + + queue = self._orig_queue + circ_id = self._orig_circ_id + proxy_addr = self._orig_proxy_addr + proxy_port = self._orig_proxy_port + socket.socket = self._orig_socket + + return False
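
A minimal usage sketch of the interface introduced above (illustrative only, not part of the patch; the queue, circuit ID, SOCKS port, and URL are hypothetical, and Python 2 is assumed, matching the code):

    # Route a urllib2 request over Tor via the patched socket.socket.
    import multiprocessing
    import urllib2

    import torsocks

    manager = multiprocessing.Manager()
    queue = manager.Queue()  # receives [circ_id, sock_name] notifications
    circ_id = "1234"         # hypothetical circuit ID

    # Inside the block, socket.socket is torsocks.torsocket, so the HTTP
    # connection below is tunnelled through the SOCKS listener assumed to
    # be at 127.0.0.1:9050; __exit__ restores socket.socket and the
    # module's globals.
    with torsocks.MonkeyPatchedSocket(queue, circ_id, 9050):
        urllib2.urlopen("http://example.com/").read()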