Skip to content

Commit

Permalink
Add Safe subclass semantics
Browse files Browse the repository at this point in the history
  • Loading branch information
przemub committed Jul 10, 2022
1 parent a1f99e4 commit 062e260
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 12 deletions.
9 changes: 8 additions & 1 deletion examples/assignment.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
def reassign(a):
b = a
return b
return b


def chain(a):
c = a
d = c
e = d
return e
29 changes: 20 additions & 9 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import argparse
import ast
import sys
from _ast import FunctionDef, stmt, Return, Name
from _ast import FunctionDef, stmt, Return, Name, Assign
from typing import TypeVar

T = TypeVar("T", bound=stmt)
Expand Down Expand Up @@ -35,17 +35,20 @@ def load_file(filename: str) -> tuple[str, ast.AST]:


class TaintVisitor(ast.NodeVisitor):
def _warn_tainted(self, node, subnode):
def _warn_tainted(self, node):
print(f"Line {node.lineno} tainted:")

print(ast.get_source_segment(self.source, node, padded=True))
print("-> ", end="")
print(ast.get_source_segment(self.source, subnode, padded=True))
tab = 0
while node:
print("->" * tab, end="")
print(ast.get_source_segment(self.source, node, padded=True))

print()
tab += 1
node = self.tainted_because.get(node, None)

def __init__(self, tainted_variables: set, source: str):
self.tainted_variables = tainted_variables
self.tainted_because = {}
self.source = source
self._tainted_nodes: list[ast.AST] = []

Expand All @@ -57,6 +60,15 @@ def visit_Name(self, node: Name):
if node.id in self.tainted_variables:
self._tainted_nodes.append(node)

def visit_Assign(self, node: Assign):
visitor = TaintVisitor(self.tainted_variables, self.source)
visitor.visit(node.value)

if visitor.tainted_nodes:
for variable in node.targets:
self.tainted_variables.add(variable.id)
self.tainted_because[variable.id] = node

def visit_Return(self, node: Return):
"""
We want to warn when returning a tainted variable, therefore
Expand All @@ -69,10 +81,9 @@ def visit_Return(self, node: Return):
visitor.visit(node.value)

if visitor.tainted_nodes:
self._tainted_nodes += [node]

for item in visitor.tainted_nodes:
self._warn_tainted(node, item)
self._tainted_nodes += [node]
self._warn_tainted(node)


def taint(function: FunctionDef, argument: str, source):
Expand Down
32 changes: 30 additions & 2 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
from unittest import TestCase

from main import load_file, find_subnode, TaintVisitor
from utils import is_safe, mark_safe


class TainerTestCase(TestCase):
class TainterTestCase(TestCase):
def test_simple(self):
source, my_ast = load_file("examples/return_argument.py")
function = find_subnode(my_ast, FunctionDef, name="simple")
Expand All @@ -24,7 +25,7 @@ def test_with_op(self):
self.assertIn(tainted_return, tainter.tainted_nodes)


def test_with_op(self):
def test_with_call(self):
source, my_ast = load_file("examples/return_argument.py")
function = find_subnode(my_ast, FunctionDef, name="with_call")
tainter = TaintVisitor({"a"}, source)
Expand All @@ -42,3 +43,30 @@ def test_reassign(self):

tainted_return = find_subnode(function, Return)
self.assertIn(tainted_return, tainter.tainted_nodes)

def test_chain(self):
source, my_ast = load_file("examples/assignment.py")
function = find_subnode(my_ast, FunctionDef, name="chain")
tainter = TaintVisitor({"a"}, source)
tainter.visit(function)

tainted_return = find_subnode(function, Return)
self.assertIn(tainted_return, tainter.tainted_nodes)


class MarkSafeTestCase(TestCase):
def test_mark_safe(self):
def a():
pass

self.assertFalse(is_safe(a))

mark_safe(a)
self.assertTrue(is_safe(a))

def test_decorator(self):
@mark_safe
def a():
pass

self.assertTrue(is_safe(a))
57 changes: 57 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""Utils for use by program creators."""
from abc import abstractmethod
from typing import Callable, TypeVar

# Attribute to be set on safe functions.
_SAFE_ATTRIBUTE = "__tainter_safe"


class _SafeMeta(type):
"""
A metaclass which instead of using normal subclassing semantics,
simply checks if _SAFE_ATTRIBUTE is declared.
By declaring _SAFE_ATTRIBUTE you tell the tainter that this method
processes its input safely.
See: mark_safe.
"""
def __instancecheck__(self, instance):
"""Instance is a subinstance iff it has _SAFE_ATTRIBUTE."""
return getattr(instance, _SAFE_ATTRIBUTE, None) is not None

@classmethod
def __subclasscheck__(cls, subclass):
"""Class is a subclass iff it has _SAFE_ATTRIBUTE."""
return getattr(subclass, _SAFE_ATTRIBUTE, None) is not None


class Safe(metaclass=_SafeMeta):
"""
A value returned when this class sub-instance is called
is untainted even if a tainted variable is passed to it.
"""

@abstractmethod
def __call__(self, *args, **kwargs):
raise NotImplementedError()


T = TypeVar("T", bound=Callable)


def mark_safe(func: T) -> T:
"""
This decorator marks a callable as "safe".
You can check that a callable is safe by:
1) checking if the callable is an instance of Safe (isinstance(func, Safe))
2) using is_safe(), which does the above
3) checking for the SAFE_ATTRIBUTE (unrecommended, since the semantics may
change.
"""

setattr(func, _SAFE_ATTRIBUTE, True)
return func


def is_safe(func: Callable) -> bool:
return isinstance(func, Safe)

0 comments on commit 062e260

Please sign in to comment.