Skip to content

Commit

Permalink
Added missing host key policy option
Browse files Browse the repository at this point in the history
  • Loading branch information
huashengdun committed Mar 14, 2018
1 parent 96eae01 commit 6ee1db2
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 12 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,6 @@ target/

# temporary file
*.swp

# known_hosts file
known_hosts
62 changes: 50 additions & 12 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,15 @@
define('address', default='127.0.0.1', help='listen address')
define('port', default=8888, help='listen port', type=int)
define('debug', default=False, help='debug mode', type=bool)
define('policy', default='reject',
help='missing host key polilcy, reject|autoadd|warning')


BUF_SIZE = 1024
DELAY = 3
base_dir = os.path.dirname(__file__)
workers = {}


def recycle(worker):
if worker.handler:
return
logging.debug('Recycling worker {}'.format(worker.id))
workers.pop(worker.id, None)
worker.close()


class Worker(object):
def __init__(self, ssh, chan, dst_addr):
self.loop = IOLoop.current()
Expand Down Expand Up @@ -204,8 +197,8 @@ def get_client_addr(self):

def ssh_connect(self):
ssh = paramiko.SSHClient()
ssh.load_system_host_keys()
ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
ssh.load_host_keys(self.settings['host_file'])
ssh.set_missing_host_key_policy(self.settings['policy'])
args = self.get_args()
dst_addr = (args[0], args[1])
logging.info('Connecting to {}:{}'.format(*dst_addr))
Expand All @@ -215,6 +208,8 @@ def ssh_connect(self):
raise ValueError('Unable to connect to {}:{}'.format(*dst_addr))
except paramiko.BadAuthenticationType:
raise ValueError('Authentication failed.')
except paramiko.BadHostKeyException:
raise ValueError('Bad host key.')
chan = ssh.invoke_shell(term='xterm')
chan.setblocking(0)
worker = Worker(ssh, chan, dst_addr)
Expand Down Expand Up @@ -278,7 +273,46 @@ def on_close(self):
worker.close()


def recycle(worker):
if worker.handler:
return
logging.debug('Recycling worker {}'.format(worker.id))
workers.pop(worker.id, None)
worker.close()


def get_host_keys(path):
if os.path.exists(path) and os.path.isfile(path):
return paramiko.hostkeys.HostKeys(filename=path)


def create_host_file(host_file):
host_keys = get_host_keys(host_file)
if not host_keys:
host_keys = get_host_keys(os.path.expanduser("~/.ssh/known_hosts"))
host_keys.save(host_file)


def get_policy_class(policy):
origin_policy = policy
policy = policy.lower()
if not policy.endswith('policy'):
policy += 'policy'

dic = {k.lower(): v for k, v in vars(paramiko.client).items()}

try:
cls = dic[policy]
except KeyError:
raise ValueError('Unknown policy {!r}'.format(origin_policy))
return cls


def main():
base_dir = os.path.dirname(__file__)
host_file = os.path.join(base_dir, 'known_hosts')
create_host_file(host_file)

settings = {
'template_path': os.path.join(base_dir, 'templates'),
'static_path': os.path.join(base_dir, 'static'),
Expand All @@ -292,7 +326,11 @@ def main():
]

parse_command_line()
settings.update(debug=options.debug)
settings.update(
debug=options.debug,
host_file=host_file,
policy=get_policy_class(options.policy)()
)
app = tornado.web.Application(handlers, **settings)
app.listen(options.port, options.address)
logging.info('Listening on {}:{}'.format(options.address, options.port))
Expand Down

0 comments on commit 6ee1db2

Please sign in to comment.