Skip to content

Commit

Permalink
Refactor run_sploit() and main()
Browse files Browse the repository at this point in the history
  • Loading branch information
borzunov committed Apr 14, 2018
1 parent 803161c commit 9464671
Showing 1 changed file with 41 additions and 35 deletions.
76 changes: 41 additions & 35 deletions client/start_sploit.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,21 @@ def parse_args():
return parser.parse_args()


def parse_distribute_argument(value):
if value is not None:
match = re.fullmatch(r'(\d+)/(\d+)', value)
def fix_args(args):
if '://' not in args.server_url:
args.server_url = 'http://' + args.server_url

if args.distribute is not None:
valid = False
match = re.fullmatch(r'(\d+)/(\d+)', args.distribute)
if match is not None:
k, n = (int(match.group(1)), int(match.group(2)))
if n >= 2 and 1 <= k <= n:
return k, n
raise ValueError('Wrong syntax for --distribute, use --distribute K/N (N >= 2, 1 <= K <= N)')
return None
args.distribute = k, n
valid = True

if not valid:
raise ValueError('Wrong syntax for --distribute, use --distribute K/N (N >= 2, 1 <= K <= N)')


SCRIPT_EXTENSIONS = ['.pl', '.py', '.rb']
Expand Down Expand Up @@ -340,29 +346,33 @@ def register_stop(self, instance_id, was_killed):
# TODO: Exclude lock from the class, rename InstanceManager to InstanceStorage


def launch_sploit(args, team_name, team_addr, attack_no, flag_format):
if exit_event.is_set():
return

# For sploits written in Python, this env variable forces the interpreter to flush
# stdout and stderr after each newline. Note that this is not default behavior
# if the sploit's output is redirected to a pipe.
env = os.environ.copy()
env['PYTHONUNBUFFERED'] = '1'

command = [os.path.abspath(args.sploit)]
if team_addr is not None:
command.append(team_addr)
need_close_fds = (os.name != 'nt')
proc = subprocess.Popen(command,
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
bufsize=1, close_fds=need_close_fds, env=env)
threading.Thread(target=lambda: process_sploit_output(
proc.stdout, args, team_name, flag_format, attack_no)).start()

return proc, instance_manager.register_start(proc)


def run_sploit(args, team_name, team_addr, attack_no, max_runtime, flag_format):
try:
with instance_manager.lock:
if exit_event.is_set():
return

# For sploits written in Python, this env variable forces the interpreter to flush
# stdout and stderr after each newline. Note that this is not default behavior
# if the sploit's output is redirected to a pipe.
env = os.environ.copy()
env['PYTHONUNBUFFERED'] = '1'

command = [os.path.abspath(args.sploit)]
if team_addr is not None:
command.append(team_addr)
need_close_fds = (os.name != 'nt')
proc = subprocess.Popen(command,
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
bufsize=1, close_fds=need_close_fds, env=env)
threading.Thread(target=lambda: process_sploit_output(
proc.stdout, args, team_name, flag_format, attack_no)).start()

instance_id = instance_manager.register_start(proc)
proc, instance_id = launch_sploit(args, team_name, team_addr, attack_no, flag_format)
except Exception as e:
if isinstance(e, FileNotFoundError):
logging.error('Sploit file or the interpreter for it not found: {}'.format(repr(e)))
Expand Down Expand Up @@ -411,12 +421,12 @@ def show_time_limit_info(args, config, max_runtime, attack_no):
PRINTED_TEAM_NAMES = 5


def get_target_teams(args, teams, distribute, attack_no):
def get_target_teams(args, teams, attack_no):
if args.not_per_team:
return {'*': None}

if distribute is not None:
k, n = distribute
if args.distribute is not None:
k, n = args.distribute
teams = {name: addr for name, addr in teams.items()
if binascii.crc32(addr.encode()) % n == k - 1}

Expand All @@ -435,11 +445,7 @@ def get_target_teams(args, teams, distribute, attack_no):

def main(args):
try:
if '://' not in args.server_url:
args.server_url = 'http://' + args.server_url

distribute = parse_distribute_argument(args.distribute)

fix_args(args)
check_sploit(args.sploit)
except (ValueError, InvalidSploitError) as e:
logging.critical(str(e))
Expand All @@ -461,7 +467,7 @@ def main(args):
if attack_no == 1:
return
logging.info('Using the old config')
teams = get_target_teams(args, config['TEAMS'], distribute, attack_no)
teams = get_target_teams(args, config['TEAMS'], attack_no)
if not teams:
if attack_no == 1:
return
Expand Down

0 comments on commit 9464671

Please sign in to comment.