Skip to content

Commit

Permalink
Support calls :)
Browse files Browse the repository at this point in the history
  • Loading branch information
przemub committed Jul 10, 2022
1 parent 42a40a5 commit 3116c18
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 37 deletions.
15 changes: 12 additions & 3 deletions examples/call.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from utils import mark_safe
from utils import mark_safe, mark_output


@mark_safe
Expand All @@ -19,5 +19,14 @@ def unsafe_call(a):


def print_test(a):
d = int(a)
return d
print(a)


@mark_output
def _generate_a_report(a):
# sys.write(a, "report.pdf")
pass


def mark_output_test(a):
_generate_a_report(a)
47 changes: 39 additions & 8 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from termcolor import colored

import utils

T = TypeVar("T", bound=ast.AST)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -135,6 +137,26 @@ def safe_functions(self) -> list[str]:

return results

@cached_property
def output_functions(self) -> list[str]:
"""Collect and return the list of functions marked with mark_output."""
# TODO: Collect also the imported functions.
# TODO: Collect functions marked without the @ decorator syntax.
# This will be useful for marking imported functions.

results = []

tree = ast.parse(self._source)
functions = find_subnodes(tree, [FunctionDef])
for function in functions:
for decorator in function.decorator_list:
if not isinstance(decorator, Name):
continue
if decorator.id == "mark_output":
results.append(function.name)

return results + utils.OUTPUT_FUNCTIONS

@property
def tainted_variables(self):
return self._tainted_because.keys()
Expand All @@ -161,7 +183,7 @@ def visit_Call(self, node: Call):
"""
If the function call was determined safe, don't go down.
If the function call was determined an output, shout when a tainted
variable is used.
variable is passed as an argument.
Otherwise, continue as usual.
"""

Expand All @@ -170,11 +192,24 @@ def visit_Call(self, node: Call):
return
if node.func.id in self.safe_functions:
return
#if node.func.id in self.output_functions:
# pass
if node.func.id in self.output_functions:
for arg in node.args:
self._check_if_not_tainted(arg)

return super().generic_visit(node)

def _check_if_not_tainted(self, node: ast.AST):
"""
Node value will be output to the user, therefore check if it uses
anything tainted and shout if so.
"""
visitor = TaintVisitor(node, self._tainted_because, self._source)

if visitor.tainted_nodes:
for _ in visitor.tainted_nodes:
self._tainted_nodes += [node]
self._warn_tainted(node)


def visit_Return(self, node: Return):
"""
Expand All @@ -184,12 +219,8 @@ def visit_Return(self, node: Return):
if node.value is None:
return

visitor = TaintVisitor(node.value, self._tainted_because, self._source)
self._check_if_not_tainted(node.value)

if visitor.tainted_nodes:
for item in visitor.tainted_nodes:
self._tainted_nodes += [node]
self._warn_tainted(node)


def taint(function: FunctionDef, argument: str, source):
Expand Down
57 changes: 31 additions & 26 deletions test.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,57 @@
from __future__ import annotations

from _ast import FunctionDef, Return
from unittest import TestCase, expectedFailure

from main import load_file, find_subnode, taint, TaintVisitor
from utils import is_safe, mark_safe


def taint_argument(file, function_name, tainted_arg="a", tainted=True):
def taint_argument(
file, function_name, tainted_arg="a", tainted: bool | None = True
):
"""
Have a test unit load a Python file, find a function in it,
taint the argument, run the tainter and make sure that
the returned value was tainted if tainted=True.
"""

def decorator(test_unit):
source, ast = load_file("examples/" + file)
function = find_subnode(
ast, FunctionDef, name=function_name
)
tainter = taint(
function, tainted_arg, source
)

def call(self: "TainterTestCase"):
source, ast = load_file("examples/" + file)
function = find_subnode(ast, FunctionDef, name=function_name)
tainter = taint(function, tainted_arg, source)

test_unit(self, tainter)
if tainted:
self.assertTaintedReturn(tainter)
if tainted is True:
self.assertTainted(function_name, tainter)
elif tainted is False:
self.assertNotTainted(function_name, tainter)
else:
self.assertNotTaintedReturn(tainter)
pass

return call

return decorator


class TainterTestCase(TestCase):
def assertTaintedReturn(self, tainter: TaintVisitor):
tainted_return = find_subnode(tainter.tree, Return)
self.assertIn(
tainted_return,
tainter.tainted_nodes,
msg="The return was not tainted, but it should be.",
)
def assertTainted(self, func: str, tainter: TaintVisitor):
self.assertInOutput(f"def {func}", tainter)

def assertNotTaintedReturn(self, tainter: TaintVisitor):
tainted_return = find_subnode(tainter.tree, Return)
self.assertNotIn(
tainted_return,
tainter.tainted_nodes,
msg="The return was tainted, but it should not be.",
)
def assertNotTainted(self, func: str, tainter: TaintVisitor):
self.assertNotInOutput(f"def {func}", tainter)

def assertInOutput(self, text: str, tainter: TaintVisitor):
self.assertIn(
text, tainter.output, msg=f"{text} was not found in the output."
)

def assertNotInOutput(self, text: str, tainter: TaintVisitor):
self.assertNotIn(
text, tainter.output, msg=f"{text} was found in the output."
)

@taint_argument("return_argument.py", "simple")
def test_simple(self, tainter):
self.assertInOutput("return a", tainter)
Expand Down Expand Up @@ -99,6 +96,14 @@ def test_safe_call(self, tainter):
def test_unsafe_call(self, tainter):
pass

@taint_argument("call.py", "print_test")
def test_print(self, tainter):
pass

@taint_argument("call.py", "mark_output_test")
def test_mark_output(self, tainter):
pass


class MarkSafeTestCase(TestCase):
def test_mark_safe(self):
Expand Down
9 changes: 9 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ def __call__(self, *args, **kwargs):
raise NotImplementedError()


# Some functions should be considered output by default.
OUTPUT_FUNCTIONS = [
"print",
"sys.write"
]


class _OutputMeta(type):
"""
A metaclass which instead of using normal subclassing semantics,
Expand All @@ -61,6 +68,8 @@ class _OutputMeta(type):

def __instancecheck__(self, instance):
"""Instance is a subinstance iff it has _SAFE_ATTRIBUTE."""
if instance.__name__ in self.OUTPUT_FUNCTIONS:
return True
return getattr(instance, _OUTPUT_ATTRIBUTE, None) is not None

@classmethod
Expand Down

0 comments on commit 3116c18

Please sign in to comment.