Skip to content

Commit

Permalink
Delete bad API token files
Browse files Browse the repository at this point in the history
Under the following conditions and API token should be considered
invalid:

- The file is empty.
- We cannot deserialize the token from the file.
- The token exists but has no expiration date.
- The token exists but has expired.

All of these conditions necessitate deleting the token file. Otherwise
we should simply return an empty token.
  • Loading branch information
waynew authored and dwoz committed Sep 3, 2019
1 parent 11016ce commit 982ed3d
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 12 deletions.
27 changes: 18 additions & 9 deletions salt/auth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@

# Import salt libs
import salt.config
import salt.exceptions
import salt.loader
import salt.payload
import salt.transport.client
import salt.utils.args
import salt.utils.dictupdate
Expand All @@ -34,7 +36,6 @@
import salt.utils.user
import salt.utils.versions
import salt.utils.zeromq
import salt.payload

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -242,16 +243,24 @@ def get_tok(self, tok):
Return the name associated with the token, or False if the token is
not valid
'''
tdata = self.tokens["{0}.get_token".format(self.opts['eauth_tokens'])](self.opts, tok)
if not tdata:
return {}

rm_tok = False
if 'expire' not in tdata:
# invalid token, delete it!
tdata = {}
try:
tdata = self.tokens["{0}.get_token".format(self.opts['eauth_tokens'])](self.opts, tok)
except salt.exceptions.SaltDeserializationError:
log.warning("Failed to load token %r - removing broken/empty file.", tok)
rm_tok = True
if tdata.get('expire', '0') < time.time():
else:
if not tdata:
return {}
rm_tok = False

if tdata.get('expire', 0) < time.time():
# If expire isn't present in the token it's invalid and needs
# to be removed. Also, if it's present and has expired - in
# other words, the expiration is before right now, it should
# be removed.
rm_tok = True

if rm_tok:
self.rm_token(tok)

Expand Down
6 changes: 6 additions & 0 deletions salt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,12 @@ class TokenAuthenticationError(SaltException):
'''


class SaltDeserializationError(SaltException):
'''
Thrown when salt cannot deserialize data.
'''


class AuthorizationError(SaltException):
'''
Thrown when runner or wheel execution fails due to permissions
Expand Down
10 changes: 8 additions & 2 deletions salt/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import salt.transport.frame
import salt.utils.immutabletypes as immutabletypes
import salt.utils.stringutils
from salt.exceptions import SaltReqTimeoutError
from salt.exceptions import SaltReqTimeoutError, SaltDeserializationError
from salt.utils.data import CaseInsensitiveDict

# Import third party libs
Expand Down Expand Up @@ -174,7 +174,13 @@ def ext_type_decoder(code, data):
)
log.debug('Msgpack deserialization failure on message: %s', msg)
gc.collect()
raise
raise six.raise_from(
SaltDeserializationError(
'Could not deserialize msgpack message.'
' See log for more info.'
),
exc,
)
finally:
gc.enable()
return ret
Expand Down
69 changes: 69 additions & 0 deletions tests/unit/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# Import pytohn libs
from __future__ import absolute_import, print_function, unicode_literals

import time

# Import Salt Testing libs
from tests.support.unit import TestCase, skipIf
from tests.support.mock import patch, call, NO_MOCK, NO_MOCK_REASON, MagicMock
Expand All @@ -14,6 +16,7 @@
import salt.master
from tests.support.case import ModuleCase
from salt import auth
from salt.exceptions import SaltDeserializationError
import salt.utils.platform


Expand All @@ -37,6 +40,72 @@ def setUp(self): # pylint: disable=W0221
self.addCleanup(patcher.stop)
self.lauth = auth.LoadAuth({}) # Load with empty opts

def test_get_tok_with_broken_file_will_remove_bad_token(self):
fake_get_token = MagicMock(side_effect=SaltDeserializationError('hi'))
patch_opts = patch.dict(self.lauth.opts, {'eauth_tokens': 'testfs'})
patch_get_token = patch.dict(
self.lauth.tokens,
{
'testfs.get_token': fake_get_token
},
)
mock_rm_token = MagicMock()
patch_rm_token = patch.object(self.lauth, 'rm_token', mock_rm_token)
with patch_opts, patch_get_token, patch_rm_token:
expected_token = 'fnord'
self.lauth.get_tok(expected_token)
mock_rm_token.assert_called_with(expected_token)

def test_get_tok_with_no_expiration_should_remove_bad_token(self):
fake_get_token = MagicMock(return_value={'no_expire_here': 'Nope'})
patch_opts = patch.dict(self.lauth.opts, {'eauth_tokens': 'testfs'})
patch_get_token = patch.dict(
self.lauth.tokens,
{
'testfs.get_token': fake_get_token
},
)
mock_rm_token = MagicMock()
patch_rm_token = patch.object(self.lauth, 'rm_token', mock_rm_token)
with patch_opts, patch_get_token, patch_rm_token:
expected_token = 'fnord'
self.lauth.get_tok(expected_token)
mock_rm_token.assert_called_with(expected_token)

def test_get_tok_with_expire_before_current_time_should_remove_token(self):
fake_get_token = MagicMock(return_value={'expire': time.time()-1})
patch_opts = patch.dict(self.lauth.opts, {'eauth_tokens': 'testfs'})
patch_get_token = patch.dict(
self.lauth.tokens,
{
'testfs.get_token': fake_get_token
},
)
mock_rm_token = MagicMock()
patch_rm_token = patch.object(self.lauth, 'rm_token', mock_rm_token)
with patch_opts, patch_get_token, patch_rm_token:
expected_token = 'fnord'
self.lauth.get_tok(expected_token)
mock_rm_token.assert_called_with(expected_token)

def test_get_tok_with_valid_expiration_should_return_token(self):
expected_token = {'expire': time.time()+1}
fake_get_token = MagicMock(return_value=expected_token)
patch_opts = patch.dict(self.lauth.opts, {'eauth_tokens': 'testfs'})
patch_get_token = patch.dict(
self.lauth.tokens,
{
'testfs.get_token': fake_get_token
},
)
mock_rm_token = MagicMock()
patch_rm_token = patch.object(self.lauth, 'rm_token', mock_rm_token)
with patch_opts, patch_get_token, patch_rm_token:
token_name = 'fnord'
actual_token = self.lauth.get_tok(token_name)
mock_rm_token.assert_not_called()
assert expected_token is actual_token, 'Token was not returned'

def test_load_name(self):
valid_eauth_load = {'username': 'test_user',
'show_timeout': False,
Expand Down
47 changes: 46 additions & 1 deletion tests/unit/tokens/test_localfs.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
# -*- coding: utf-8 -*-
'''
Tests the localfs tokens interface.
'''
from __future__ import absolute_import, print_function, unicode_literals

import os

import salt.utils.files
import salt.exceptions
import salt.tokens.localfs
import salt.utils.files

from tests.support.unit import TestCase, skipIf
from tests.support.helpers import with_tempdir
Expand Down Expand Up @@ -51,3 +55,44 @@ def test_write_token(self, tmpdir):
assert rename.called_with == [
((temp_t_path, t_path), {})
], rename.called_with


class TestLocalFS(unittest.TestCase):
def setUp(self):
# Default expected data
self.expected_data = {'this': 'is', 'some': 'token data'}

@with_tempdir()
def test_get_token_should_return_token_if_exists(self, tempdir):
opts = {'token_dir': tempdir}
tok = salt.tokens.localfs.mk_token(
opts=opts,
tdata=self.expected_data,
)['token']
actual_data = salt.tokens.localfs.get_token(opts=opts, tok=tok)
self.assertDictEqual(self.expected_data, actual_data)

@with_tempdir()
def test_get_token_should_raise_SaltDeserializationError_if_token_file_is_empty(self, tempdir):
opts = {'token_dir': tempdir}
tok = salt.tokens.localfs.mk_token(
opts=opts,
tdata=self.expected_data,
)['token']
with open(os.path.join(tempdir, tok), 'w') as f:
f.truncate()
with self.assertRaises(salt.exceptions.SaltDeserializationError) as e:
salt.tokens.localfs.get_token(opts=opts, tok=tok)

@with_tempdir()
def test_get_token_should_raise_SaltDeserializationError_if_token_file_is_malformed(self, tempdir):
opts = {'token_dir': tempdir}
tok = salt.tokens.localfs.mk_token(
opts=opts,
tdata=self.expected_data,
)['token']
with open(os.path.join(tempdir, tok), 'w') as f:
f.truncate()
f.write('this is not valid msgpack data')
with self.assertRaises(salt.exceptions.SaltDeserializationError) as e:
salt.tokens.localfs.get_token(opts=opts, tok=tok)

0 comments on commit 982ed3d

Please sign in to comment.