Skip to content

Commit

Permalink
Prevent data being truncated over persistent connection socket (ansib…
Browse files Browse the repository at this point in the history
…le#43885)

* Change how data is sent to the persistent connection socket.

We can't rely on readline(), so send the size of the data first. We can
then read that many bytes from the stream on the recieving end.

* Set pty to noncanonical mode before sending

* Now that we send data length, we don't need a sentinel anymore

* Copy socket changes to persistent, too

* Use os.write instead of fdopen()ing and using that.

* Follow pickle with sha1sum of pickle

* Swap order of vars and init being passed to ansible-connection
  • Loading branch information
Qalthos authored Aug 10, 2018
1 parent 77bff99 commit f221105
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 60 deletions.
39 changes: 20 additions & 19 deletions bin/ansible-connection
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ except Exception:
pass

import fcntl
import hashlib
import os
import signal
import socket
Expand All @@ -36,6 +37,23 @@ from ansible.utils.display import Display
from ansible.utils.jsonrpc import JsonRpcServer


def read_stream(byte_stream):
size = int(byte_stream.readline().strip())

data = byte_stream.read(size)
if len(data) < size:
raise Exception("EOF found before data was complete")

data_hash = to_text(byte_stream.readline().strip())
if data_hash != hashlib.sha1(data).hexdigest():
raise Exception("Read {0} bytes, but data did not match checksum".format(size))

# restore escaped loose \r characters
data = data.replace(br'\r', b'\r')

return data


@contextmanager
def file_lock(lock_path):
"""
Expand Down Expand Up @@ -204,25 +222,8 @@ def main():

try:
# read the play context data via stdin, which means depickling it
cur_line = stdin.readline()
init_data = b''

while cur_line.strip() != b'#END_INIT#':
if cur_line == b'':
raise Exception("EOF found before init data was complete")
init_data += cur_line
cur_line = stdin.readline()

cur_line = stdin.readline()
vars_data = b''

while cur_line.strip() != b'#END_VARS#':
if cur_line == b'':
raise Exception("EOF found before vars data was complete")
vars_data += cur_line
cur_line = stdin.readline()
# restore escaped loose \r characters
vars_data = vars_data.replace(br'\r', b'\r')
vars_data = read_stream(stdin)
init_data = read_stream(stdin)

if PY3:
pc_data = cPickle.loads(init_data, encoding='bytes')
Expand Down
39 changes: 18 additions & 21 deletions lib/ansible/executor/task_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
import json
import subprocess
import sys
import termios
import traceback

from ansible import constants as C
from ansible.errors import AnsibleError, AnsibleParserError, AnsibleUndefinedVariable, AnsibleConnectionFailure, AnsibleActionFail, AnsibleActionSkip
from ansible.executor.task_result import TaskResult
from ansible.module_utils.six import iteritems, string_types, binary_type
from ansible.module_utils.six.moves import cPickle
from ansible.module_utils._text import to_text, to_native
from ansible.module_utils.connection import write_to_file_descriptor
from ansible.playbook.conditional import Conditional
from ansible.playbook.task import Task
from ansible.template import Templar
Expand Down Expand Up @@ -920,28 +921,24 @@ def find_file_in_path(filename):
[python, find_file_in_path('ansible-connection'), to_text(os.getppid())],
stdin=slave, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
stdin = os.fdopen(master, 'wb', 0)
os.close(slave)

# Need to force a protocol that is compatible with both py2 and py3.
# That would be protocol=2 or less.
# Also need to force a protocol that excludes certain control chars as
# stdin in this case is a pty and control chars will cause problems.
# that means only protocol=0 will work.
src = cPickle.dumps(self._play_context.serialize(), protocol=0)
stdin.write(src)
stdin.write(b'\n#END_INIT#\n')

src = cPickle.dumps(variables, protocol=0)
# remaining \r fail to round-trip the socket
src = src.replace(b'\r', br'\r')
stdin.write(src)
stdin.write(b'\n#END_VARS#\n')

stdin.flush()

(stdout, stderr) = p.communicate()
stdin.close()
# We need to set the pty into noncanonical mode. This ensures that we
# can receive lines longer than 4095 characters (plus newline) without
# truncating.
old = termios.tcgetattr(master)
new = termios.tcgetattr(master)
new[3] = new[3] & ~termios.ICANON

try:
termios.tcsetattr(master, termios.TCSANOW, new)
write_to_file_descriptor(master, variables)
write_to_file_descriptor(master, self._play_context.serialize())

(stdout, stderr) = p.communicate()
finally:
termios.tcsetattr(master, termios.TCSANOW, old)
os.close(master)

if p.returncode == 0:
result = json.loads(to_text(stdout, errors='surrogate_then_replace'))
Expand Down
25 changes: 25 additions & 0 deletions lib/ansible/module_utils/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
# USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
import hashlib
import json
import socket
import struct
Expand All @@ -36,6 +37,30 @@
from functools import partial
from ansible.module_utils._text import to_bytes, to_text
from ansible.module_utils.six import iteritems
from ansible.module_utils.six.moves import cPickle


def write_to_file_descriptor(fd, obj):
"""Handles making sure all data is properly written to file descriptor fd.
In particular, that data is encoded in a character stream-friendly way and
that all data gets written before returning.
"""
# Need to force a protocol that is compatible with both py2 and py3.
# That would be protocol=2 or less.
# Also need to force a protocol that excludes certain control chars as
# stdin in this case is a pty and control chars will cause problems.
# that means only protocol=0 will work.
src = cPickle.dumps(obj, protocol=0)

# raw \r characters will not survive pty round-trip
# They should be rehydrated on the receiving end
src = src.replace(b'\r', br'\r')
data_hash = to_bytes(hashlib.sha1(src).hexdigest())

os.write(fd, b'%d\n' % len(src))
os.write(fd, src)
os.write(fd, b'%s\n' % data_hash)


def send_data(s, data):
Expand Down
38 changes: 18 additions & 20 deletions lib/ansible/plugins/connection/persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@
import json
import subprocess
import sys
import termios

from ansible import constants as C
from ansible.plugins.connection import ConnectionBase
from ansible.module_utils._text import to_text
from ansible.module_utils.six.moves import cPickle
from ansible.module_utils.connection import Connection as SocketConnection
from ansible.module_utils.connection import Connection as SocketConnection, write_to_file_descriptor
from ansible.errors import AnsibleError

try:
Expand Down Expand Up @@ -109,26 +109,24 @@ def find_file_in_path(filename):
[python, find_file_in_path('ansible-connection'), to_text(os.getppid())],
stdin=slave, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
stdin = os.fdopen(master, 'wb', 0)
os.close(slave)

# Need to force a protocol that is compatible with both py2 and py3.
# That would be protocol=2 or less.
# Also need to force a protocol that excludes certain control chars as
# stdin in this case is a pty and control chars will cause problems.
# that means only protocol=0 will work.
src = cPickle.dumps(self._play_context.serialize(), protocol=0)
stdin.write(src)
stdin.write(b'\n#END_INIT#\n')

src = cPickle.dumps({'ansible_command_timeout': self.get_option('persistent_command_timeout')}, protocol=0)
stdin.write(src)
stdin.write(b'\n#END_VARS#\n')

stdin.flush()

(stdout, stderr) = p.communicate()
stdin.close()
# We need to set the pty into noncanonical mode. This ensures that we
# can receive lines longer than 4095 characters (plus newline) without
# truncating.
old = termios.tcgetattr(master)
new = termios.tcgetattr(master)
new[3] = new[3] & ~termios.ICANON

try:
termios.tcsetattr(master, termios.TCSANOW, new)
write_to_file_descriptor(master, {'ansible_command_timeout': self.get_option('persistent_command_timeout')})
write_to_file_descriptor(master, self._play_context.serialize())

(stdout, stderr) = p.communicate()
finally:
termios.tcsetattr(master, termios.TCSANOW, old)
os.close(master)

if p.returncode == 0:
result = json.loads(to_text(stdout, errors='surrogate_then_replace'))
Expand Down

0 comments on commit f221105

Please sign in to comment.