Skip to content

Commit

Permalink
using explicit type casting required for python3
Browse files Browse the repository at this point in the history
* using the str wrapper cstr and bytes wrapper cbytes
  is required to support both, python 2+3
  • Loading branch information
anthraxx committed Oct 17, 2013
1 parent 87a2c0a commit 74f052a
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 18 deletions.
60 changes: 46 additions & 14 deletions shellnoob.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import urllib
from tempfile import mktemp, NamedTemporaryFile
from subprocess import call, Popen, PIPE
import binascii
PY2 = sys.version_info.major == 2

try:
Expand All @@ -45,6 +46,11 @@

if PY2:
input = raw_input
cbytes = lambda source, encoding: bytes(source)
cstr = lambda source, encoding: str(source)
else:
cbytes = lambda source, encoding: bytes(source, encoding)
cstr = lambda source, encoding: str(source, encoding)

######################
### main functions ###
Expand Down Expand Up @@ -338,6 +344,7 @@ def get_comment_as_char(self, kernel=None, hardware=None):
######################

def do_resolve_syscall(self, syscall, kernel=None, hardware=None):
global cstr
kernel = kernel if kernel is not None else self.kernel
hardware = hardware if hardware is not None else self.hardware

Expand All @@ -362,7 +369,7 @@ def do_resolve_syscall(self, syscall, kernel=None, hardware=None):
output, error = p.communicate()
retval = p.returncode
if retval == 0:
print('%s ~> %s' % (platform, output))
print('%s ~> %s' % (platform, cstr(output, "utf-8")))
else:
print('ERROR: reval %s while resolving syscall %s' % (retval, syscall), file=sys.stderr)

Expand Down Expand Up @@ -390,6 +397,7 @@ def do_resolve_const(self, const):


def do_resolve_errno(self, errno):
global cstr
includes = ['string.h']

body = 'printf("%%s", strerror(%s)); return 0;' % (errno)
Expand All @@ -404,12 +412,13 @@ def do_resolve_errno(self, errno):
output, error = p.communicate()
retval = p.returncode
if retval == 0:
print('%s ~> %s' % (errno, output))
print('%s ~> %s' % (errno, cstr(output, "utf-8")))
else:
print('ERROR: reval %s while resolving errno %s' % (retval, errno), file=sys.stderr)


def do_interactive_mode(self, args):
global cbytes
asm_to_opcode_flag = None
if '--to-opcode' in args:
asm_to_opcode_flag = True
Expand Down Expand Up @@ -457,7 +466,7 @@ def do_interactive_mode(self, args):
try:
_hex = _hex.replace(' ','').strip(' \t\n')
asm = self.hex_to_pretty(_hex)
print('%s ~> %s' % (cbytes(_hex), asm))
print('%s ~> %s' % (cbytes(_hex, 'utf-8'), asm))
except Exception as e:
print('ERROR: %s' % e, file=sys.stderr)
if self.verbose >= 3:
Expand All @@ -467,6 +476,7 @@ def do_interactive_mode(self, args):


def do_conversion(self, input_fp, output_fp, input_fmt, output_fmt):
global cbytes
if self.verbose >= 0:
if input_fmt == '-':
msg = 'Converting from stdin (%s) ' % input_fmt
Expand Down Expand Up @@ -494,6 +504,8 @@ def do_conversion(self, input_fp, output_fp, input_fmt, output_fmt):
print('ERROR: conversion mode "%s" is not supported.' % conv_func_name, file=sys.stderr)
sys.exit(2)

if not isinstance(_output, bytes):
_output = cbytes(_output, 'utf-8')
# writing the output
if output_fp == '-':
sys.stdout.write(_output)
Expand Down Expand Up @@ -629,22 +641,23 @@ def do_exe_patch(self, exe_fp, data, file_offset=None, vm_address=None, replace=
###################

def asm_to_hex(self, asm, with_breakpoint=None):
global cstr
if self.verbose >= 3: print('IN asm_to_hex', file=sys.stderr)
with_breakpoint = with_breakpoint if with_breakpoint is not None else self.with_breakpoint

obj = self.asm_to_obj(asm, with_breakpoint)
_hex = self.obj_to_hex(obj, with_breakpoint=False)

if self.verbose >= 3: print('OUT asm_to_hex', file=sys.stderr)
return _hex
return cstr(_hex, 'utf-8')

def bin_to_hex(self, _bin, with_breakpoint=None):
global cbytes
if self.verbose >= 3: print('IN bin_to_hex', file=sys.stderr)
with_breakpoint = with_breakpoint if with_breakpoint is not None else self.with_breakpoint

prepend = self.get_breakpoint_hex() if with_breakpoint else ''
if self.verbose >= 3: print('OUT bin_to_hex', file=sys.stderr)
return prepend + str(binascii.b2a_hex(_bin))
return cbytes(prepend, 'utf-8') + binascii.hexlify(_bin)

def obj_to_hex(self, obj, with_breakpoint=None):
if self.verbose >= 3: print('IN obj_to_hex', file=sys.stderr)
Expand Down Expand Up @@ -740,6 +753,8 @@ def hex_to_obj(self, _hex, with_breakpoint=None):
if self.verbose >= 3: print('IN hex_to_obj', file=sys.stderr)
with_breakpoint = with_breakpoint if with_breakpoint is not None else self.with_breakpoint

if not isinstance(_hex, str):
_hex = cstr(_hex, 'utf-8')
if len(_hex) != 0 and _hex.endswith('\n'):
_hex = _hex.rstrip('\n')
print('Warning: stripped a \'\\n\' at the end of the hex', file=sys.stderr)
Expand All @@ -765,9 +780,12 @@ def hex_to_exe(self, _hex, with_breakpoint=None):
return exe

def hex_to_bin(self, _hex, with_breakpoint=None):
global cstr
if self.verbose >= 3: print('IN hex_to_bin', file=sys.stderr)
with_breakpoint = with_breakpoint if with_breakpoint is not None else self.with_breakpoint

if not isinstance(_hex, str):
_hex = cstr(_hex, 'utf-8')
if len(_hex) != 0 and _hex.endswith('\n'):
_hex = _hex.rstrip('\n')
print('Warning: stripped a \'\\n\' at the end of the hex', file=sys.stderr)
Expand All @@ -778,12 +796,14 @@ def hex_to_bin(self, _hex, with_breakpoint=None):
_hex = prepend + _hex

if self.verbose >= 3: print('OUT hex_to_bin', file=sys.stderr)
return _hex.decode('hex')
return binascii.unhexlify(_hex)

def hex_to_c(self, _hex, with_breakpoint=None):
if self.verbose >= 3: print('IN hex_to_c', file=sys.stderr)
with_breakpoint = with_breakpoint if with_breakpoint is not None else self.with_breakpoint

if not isinstance(_hex, str):
_hex = cstr(_hex, 'utf-8')
if len(_hex) != 0 and _hex.endswith('\n'):
_hex = _hex.rstrip('\n')
print('Warning: stripped a \'\\n\' at the end of the hex', file=sys.stderr)
Expand All @@ -802,9 +822,12 @@ def hex_to_c(self, _hex, with_breakpoint=None):
return out

def hex_to_python(self, _hex, with_breakpoint=None):
global cstr
if self.verbose >= 3: print('IN hex_to_python', file=sys.stderr)
with_breakpoint = with_breakpoint if with_breakpoint is not None else self.with_breakpoint

if not isinstance(_hex, str):
_hex = cstr(_hex, 'utf-8')
if len(_hex) != 0 and _hex.endswith('\n'):
_hex = _hex.rstrip('\n')
print('Warning: stripped a \'\\n\' at the end of the hex', file=sys.stderr)
Expand Down Expand Up @@ -904,17 +927,18 @@ def obj_to_pretty(self, obj, with_breakpoint=None):
#########################

def asm_to_obj(self, asm, with_breakpoint=None):
global cstr
if self.verbose >= 3: print('IN asm_to_obj', file=sys.stderr)
with_breakpoint = with_breakpoint if with_breakpoint is not None else self.with_breakpoint

asm += '\n' # as complains if the asm doesn't end with newline

if isinstance(asm, bytes):
asm = cstr(asm, 'utf-8')
prepend = self.hex_to_asm_bytes(self.get_breakpoint_hex()) if with_breakpoint else ''
asm = prepend + asm

asm = prepend + asm + '\n'
tmp_asm_f = NamedTemporaryFile(delete=False)
tmp_asm_fp = tmp_asm_f.name
tmp_asm_f.write(asm)
tmp_asm_f.write(asm.encode("utf-8"))
tmp_asm_f.close()

tmp_obj_fp = mktemp()
Expand Down Expand Up @@ -982,7 +1006,7 @@ def obj_to_asm(self, obj, with_breakpoint=None):
_hex = m.group(1).replace(' ', '').strip(' \t\n')
help_asm = self.hex_to_asm_bytes(_hex).rstrip('\n')
try:
_ascii = '.ascii "%s"' % _hex.decode('hex').decode('ascii')
_ascii = '.ascii "%s"' % _hex
_ascii = _ascii.strip(' \t\n')
except UnicodeDecodeError:
_ascii = ''
Expand Down Expand Up @@ -1052,9 +1076,12 @@ def obj_to_exe(self, obj, with_breakpoint=None):
return exe

def hex_to_safeasm(self, _hex, with_breakpoint=None):
global cstr
if self.verbose >= 3: print('IN hex_to_safeasm', file=sys.stderr)
with_breakpoint = with_breakpoint if with_breakpoint is not None else self.with_breakpoint

if not isinstance(_hex, str):
_hex = cstr(_hex, 'utf-8')
if len(_hex) != 0 and _hex.endswith('\n'):
_hex = _hex.rstrip('\n')
print('Warning: stripped a \'\\n\' at the end of the hex', file=sys.stderr)
Expand All @@ -1081,12 +1108,15 @@ def hex_to_completec(self, _hex, with_breakpoint=None):
return completec

def c_to_exe(self, c, with_breakpoint=None):
global cbytes
# NOTE assumption: the input is "compileable C"
if self.verbose >= 3: print('IN c_to_exe', file=sys.stderr)

if with_breakpoint:
raise Exception('the with_breakpoint option is NOT supported in c_to_exe')

if not isinstance(c, bytes):
c = cbytes(c, 'utf-8')
tmp_c_f = NamedTemporaryFile(suffix='.c', delete=False)
tmp_c_fp = tmp_c_f.name
tmp_c_f.write(c)
Expand Down Expand Up @@ -1123,7 +1153,7 @@ def hex_to_inss(self, _hex):
inss = filter(lambda i:i.strip(' \t'), inss)
inss = map(lambda i:i.split('#')[0], inss)
inss = map(lambda i:i.strip(' \t'), inss)
return inss
return list(inss)

def inss_to_asm(self, inss):
out = '\n'.join(inss)
Expand All @@ -1149,6 +1179,7 @@ def hex_to_asm_bytes(self, _hex):
return asm

def include_and_body_to_exe_fp(self, includes, body):
global cbytes
std_includes = set(('stdio.h', 'stdlib.h', 'errno.h'))
includes = set(includes)
includes.update(std_includes)
Expand All @@ -1165,7 +1196,7 @@ def include_and_body_to_exe_fp(self, includes, body):
tmp_exe_fp = mktemp()

with open(tmp_c_fp, 'wb') as f:
f.write(c_prog)
f.write(cbytes(c_prog, 'utf-8'))

cmd = 'gcc %s -o %s' % (tmp_c_fp, tmp_exe_fp)
retval = self.exec_cmd(cmd, 'include_and_body_to_exe_fp')
Expand All @@ -1184,6 +1215,7 @@ def get_start_address(self, exe_fp):
_out, _err = p.communicate()
assert p.returncode == 0

_out = cstr(_out, 'utf-8')
for line in _out.split('\n'):
line = line.strip(' \t\n')
m = re.search('^start address (0x[0-9a-f]+)$', line)
Expand Down
8 changes: 4 additions & 4 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
except ImportError:
pass

from shellnoob import ShellNoob
from shellnoob import ShellNoob, cstr, cbytes

GREEN = '\033[92m'
RED = '\033[91m'
Expand Down Expand Up @@ -244,10 +244,10 @@ def run_with_args_input(_input='', args=''):
cmd = '%s %s' % (SHELLNOOB_FP, args)
print 'Launching: %s (with input)' % (cmd)
p = Popen(cmd, shell=True, stdin=PIPE, stdout=PIPE, stderr=PIPE)
stdout, stderr = p.communicate(input=_input)
stdout, stderr = p.communicate(input=cbytes(_input, 'utf-8'))
retval = p.returncode

return stdout, stderr, retval
return cstr(stdout, 'utf-8'), cstr(stderr, 'utf-8'), int(retval)


def run_with_args(args=''):
Expand All @@ -258,7 +258,7 @@ def run_with_args(args=''):
stdout, stderr = p.communicate(input='')
retval = p.returncode

return stdout, stderr, retval
return cstr(stdout, 'utf-8'), cstr(stderr, 'utf-8'), int(retval)

def run_all_tests():
kernel, hardware = ShellNoob.get_kernel(), ShellNoob.get_hardware()
Expand Down

0 comments on commit 74f052a

Please sign in to comment.