Skip to content

Commit

Permalink
Fix for persistent connection plugin on Python3 (ansible#24431)
Browse files Browse the repository at this point in the history
Fix for persistent connection plugin on Python3.  Note that fixes are also needed to each terminal plugin.  This PR only fixes the ios terminal (as proof that this approach is workable.)  Future PRs can address the other terminal types.

* On Python3, pickle needs to work with byte strings, not text strings.
* Set the pickle protocol version to 0 because we're using a pty to feed data to the connection plugin.  A pty can't have control characters.  So we have to send ascii only.  That means
only using protocol=0 for pickling the data.
* ansible-connection isn't being used with py3 in the bug but it needs
several changes to work with python3.
* In python3, closing the pty too early causes no data to be sent.  So
leave stdin open until after we finish with the ansible-connection
process.
* Fix typo using traceback.format_exc()
* Cleanup unnecessary StringIO, BytesIO, and to_bytes calls
* Modify the network_cli and terminal plugins for py3 compat.  Lots of mixing of text and byte strings that needs to be straightened out to be compatible with python3
* Documentation for the bytes<=>text strategy for terminal plugins
* Update unittests for more bytes-oriented internals

Fixes ansible#24355
  • Loading branch information
abadger authored May 12, 2017
1 parent e539726 commit d834412
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 111 deletions.
58 changes: 30 additions & 28 deletions bin/ansible-connection
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ from io import BytesIO

from ansible import constants as C
from ansible.module_utils._text import to_bytes, to_native
from ansible.module_utils.six.moves import cPickle, StringIO
from ansible.module_utils.six import PY3
from ansible.module_utils.six.moves import cPickle
from ansible.playbook.play_context import PlayContext
from ansible.plugins import connection_loader
from ansible.utils.path import unfrackpath, makedirs_safe
Expand Down Expand Up @@ -73,11 +74,11 @@ def do_fork():
sys.exit(0)

if C.DEFAULT_LOG_PATH != '':
out_file = file(C.DEFAULT_LOG_PATH, 'a+')
err_file = file(C.DEFAULT_LOG_PATH, 'a+', 0)
out_file = open(C.DEFAULT_LOG_PATH, 'ab+')
err_file = open(C.DEFAULT_LOG_PATH, 'ab+', 0)
else:
out_file = file('/dev/null', 'a+')
err_file = file('/dev/null', 'a+', 0)
out_file = open('/dev/null', 'ab+')
err_file = open('/dev/null', 'ab+', 0)

os.dup2(out_file.fileno(), sys.stdout.fileno())
os.dup2(err_file.fileno(), sys.stderr.fileno())
Expand All @@ -90,7 +91,7 @@ def do_fork():
sys.exit(1)

def send_data(s, data):
packed_len = struct.pack('!Q',len(data))
packed_len = struct.pack('!Q', len(data))
return s.sendall(packed_len + data)

def recv_data(s):
Expand All @@ -101,7 +102,7 @@ def recv_data(s):
if not d:
return None
data += d
data_len = struct.unpack('!Q',data[:header_len])[0]
data_len = struct.unpack('!Q', data[:header_len])[0]
data = data[header_len:]
while len(data) < data_len:
d = s.recv(data_len - len(data))
Expand Down Expand Up @@ -211,11 +212,9 @@ class Server():
pass
elif data.startswith(b'CONTEXT: '):
display.display("socket operation is CONTEXT", log_only=True)
pc_data = data.split(b'CONTEXT: ')[1]
pc_data = data.split(b'CONTEXT: ', 1)[1]

src = StringIO(pc_data)
pc_data = cPickle.load(src)
src.close()
pc_data = cPickle.loads(pc_data)

pc = PlayContext()
pc.deserialize(pc_data)
Expand All @@ -234,12 +233,12 @@ class Server():

display.display("socket operation completed with rc %s" % rc, log_only=True)

send_data(s, to_bytes(str(rc)))
send_data(s, to_bytes(rc))
send_data(s, to_bytes(stdout))
send_data(s, to_bytes(stderr))
s.close()
except Exception as e:
display.display(traceback.format_exec(), log_only=True)
display.display(traceback.format_exc(), log_only=True)
finally:
# when done, close the connection properly and cleanup
# the socket file so it can be recreated
Expand All @@ -254,21 +253,25 @@ class Server():
os.remove(self.path)

def main():
# Need stdin as a byte stream
if PY3:
stdin = sys.stdin.buffer
else:
stdin = sys.stdin

try:
# read the play context data via stdin, which means depickling it
# FIXME: as noted above, we will probably need to deserialize the
# connection loader here as well at some point, otherwise this
# won't find role- or playbook-based connection plugins
cur_line = sys.stdin.readline()
init_data = ''
while cur_line.strip() != '#END_INIT#':
if cur_line == '':
raise Exception("EOL found before init data was complete")
cur_line = stdin.readline()
init_data = b''
while cur_line.strip() != b'#END_INIT#':
if cur_line == b'':
raise Exception("EOF found before init data was complete")
init_data += cur_line
cur_line = sys.stdin.readline()
src = BytesIO(to_bytes(init_data))
pc_data = cPickle.load(src)
cur_line = stdin.readline()
pc_data = cPickle.loads(init_data)

pc = PlayContext()
pc.deserialize(pc_data)
Expand Down Expand Up @@ -319,10 +322,10 @@ def main():
# the connection will timeout here. Need to make this more resilient.
rc = 0
while rc == 0:
data = sys.stdin.readline()
if data == '':
data = stdin.readline()
if data == b'':
break
if data.strip() == '':
if data.strip() == b'':
continue
sf = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
attempts = 1
Expand All @@ -342,11 +345,10 @@ def main():

# send the play_context back into the connection so the connection
# can handle any privilege escalation activities
pc_data = 'CONTEXT: %s' % src.getvalue()
send_data(sf, to_bytes(pc_data))
src.close()
pc_data = b'CONTEXT: %s' % init_data
send_data(sf, pc_data)

send_data(sf, to_bytes(data.strip()))
send_data(sf, data.strip())

rc = int(recv_data(sf), 10)
stdout = recv_data(sf)
Expand Down
60 changes: 32 additions & 28 deletions lib/ansible/plugins/connection/network_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,18 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type

import re
import socket
import json
import logging
import re
import signal
import datetime
import socket
import traceback
import logging
from collections import Sequence

from ansible import constants as C
from ansible.errors import AnsibleConnectionFailure
from ansible.module_utils.six.moves import StringIO
from ansible.module_utils.six import BytesIO, binary_type, text_type
from ansible.module_utils._text import to_bytes, to_text
from ansible.plugins import terminal_loader
from ansible.plugins.connection import ensure_connect
from ansible.plugins.connection.paramiko_ssh import Connection as _Connection
Expand Down Expand Up @@ -113,7 +114,7 @@ def open_shell(self):
self._terminal.on_authorize(passwd=auth_pass)

display.display('shell successfully opened', log_only=True)
return (0, 'ok', '')
return (0, b'ok', b'')

def close(self):
display.display('closing connection', log_only=True)
Expand All @@ -131,11 +132,11 @@ def close_shell(self):
self._shell.close()
self._shell = None

return (0, 'ok', '')
return (0, b'ok', b'')

def receive(self, obj=None):
"""Handles receiving of output from command"""
recv = StringIO()
recv = BytesIO()
handled = False

self._matched_prompt = None
Expand All @@ -162,30 +163,30 @@ def send(self, obj):
try:
command = obj['command']
self._history.append(command)
self._shell.sendall('%s\r' % command)
self._shell.sendall(b'%s\r' % command)
if obj.get('sendonly'):
return
return self.receive(obj)
except (socket.timeout, AttributeError) as exc:
except (socket.timeout, AttributeError):
display.display(traceback.format_exc(), log_only=True)
raise AnsibleConnectionFailure("timeout trying to send command: %s" % command.strip())

def _strip(self, data):
"""Removes ANSI codes from device response"""
for regex in self._terminal.ansi_re:
data = regex.sub('', data)
data = regex.sub(b'', data)
return data

def _handle_prompt(self, resp, obj):
"""Matches the command prompt and responds"""
if not isinstance(obj['prompt'], list):
if isinstance(obj, (binary_type, text_type)) or not isinstance(obj['prompt'], Sequence):
obj['prompt'] = [obj['prompt']]
prompts = [re.compile(r, re.I) for r in obj['prompt']]
answer = obj['answer']
for regex in prompts:
match = regex.search(resp)
if match:
self._shell.sendall('%s\r' % answer)
self._shell.sendall(b'%s\r' % answer)
return True

def _sanitize(self, resp, obj=None):
Expand All @@ -196,7 +197,7 @@ def _sanitize(self, resp, obj=None):
if (command and line.startswith(command.strip())) or self._matched_prompt.strip() in line:
continue
cleaned.append(line)
return str("\n".join(cleaned)).strip()
return b"\n".join(cleaned).strip()

def _find_prompt(self, response):
"""Searches the buffered response for a matching command prompt"""
Expand Down Expand Up @@ -225,45 +226,48 @@ def alarm_handler(self, signum, frame):
def exec_command(self, cmd):
"""Executes the cmd on in the shell and returns the output
The method accepts two forms of cmd. The first form is as a
The method accepts two forms of cmd. The first form is as a byte
string that represents the command to be executed in the shell. The
second form is as a JSON string with additional keyword.
second form is as a utf8 JSON byte string with additional keywords.
Keywords supported for cmd:
* command - the command string to execute
* prompt - the expected prompt generated by executing command
* answer - the string to respond to the prompt with
* sendonly - bool to disable waiting for response
:arg cmd: the string that represents the command to be executed
which can be a single command or a json encoded string
:arg cmd: the byte string that represents the command to be executed
which can be a single command or a json encoded string.
:returns: a tuple of (return code, stdout, stderr). The return
code is an integer and stdout and stderr are strings
code is an integer and stdout and stderr are byte strings
"""
try:
obj = json.loads(cmd)
obj = json.loads(to_text(cmd, errors='surrogate_or_strict'))
obj = dict((k, to_bytes(v, errors='surrogate_or_strict', nonstring='passthru')) for k, v in obj.items())
except (ValueError, TypeError):
obj = {'command': str(cmd).strip()}
obj = {'command': to_bytes(cmd.strip(), errors='surrogate_or_strict')}

if obj['command'] == 'close_shell()':
if obj['command'] == b'close_shell()':
return self.close_shell()
elif obj['command'] == 'open_shell()':
elif obj['command'] == b'open_shell()':
return self.open_shell()
elif obj['command'] == 'prompt()':
return (0, self._matched_prompt, '')
elif obj['command'] == b'prompt()':
return (0, self._matched_prompt, b'')

try:
if self._shell is None:
self.open_shell()
except AnsibleConnectionFailure as exc:
return (1, '', str(exc))
# FIXME: Feels like we should raise this rather than return it
return (1, b'', to_bytes(exc))

try:
if not signal.getsignal(signal.SIGALRM):
signal.signal(signal.SIGALRM, self.alarm_handler)
signal.alarm(self._play_context.timeout)
out = self.send(obj)
signal.alarm(0)
return (0, out, '')
return (0, out, b'')
except (AnsibleConnectionFailure, ValueError) as exc:
return (1, '', str(exc))
# FIXME: Feels like we should raise this rather than return it
return (1, b'', to_bytes(exc))
16 changes: 10 additions & 6 deletions lib/ansible/plugins/connection/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import sys

from ansible.module_utils._text import to_bytes
from ansible.module_utils.six.moves import cPickle, StringIO
from ansible.module_utils.six.moves import cPickle
from ansible.plugins.connection import ConnectionBase

try:
Expand Down Expand Up @@ -52,16 +52,20 @@ def _do_it(self, action):
stdin = os.fdopen(master, 'wb', 0)
os.close(slave)

src = StringIO()
cPickle.dump(self._play_context.serialize(), src)
stdin.write(src.getvalue())
src.close()
# Need to force a protocol that is compatible with both py2 and py3.
# That would be protocol=2 or less.
# Also need to force a protocol that excludes certain control chars as
# stdin in this case is a pty and control chars will cause problems.
# that means only protocol=0 will work.
src = cPickle.dumps(self._play_context.serialize(), protocol=0)
stdin.write(src)

stdin.write(b'\n#END_INIT#\n')
stdin.write(to_bytes(action))
stdin.write(b'\n\n')
stdin.close()

(stdout, stderr) = p.communicate()
stdin.close()

return (p.returncode, stdout, stderr)

Expand Down
Loading

0 comments on commit d834412

Please sign in to comment.