Skip to content

Commit

Permalink
[Typing] Initial Iptables typing (aerleon#264)
Browse files Browse the repository at this point in the history
ankenyr authored Mar 14, 2023
1 parent 1d41ec1 commit 334e88c
Showing 1 changed file with 73 additions and 30 deletions.
103 changes: 73 additions & 30 deletions aerleon/lib/iptables.py
Original file line number Diff line number Diff line change
@@ -18,10 +18,13 @@

import re
from string import Template # pylint: disable=g-importing-member
from typing import Any, Dict, List, Optional, Set, Tuple, Union

from absl import logging

from aerleon.lib import aclgenerator, nacaddr
from aerleon.lib.nacaddr import IPv4, IPv6
from aerleon.lib.policy import Policy, Term


class Term(aclgenerator.Term):
@@ -65,7 +68,15 @@ class Term(aclgenerator.Term):
'sample': '',
}

def __init__(self, term, filter_name, trackstate, filter_action, af='inet', verbose=True):
def __init__(
self,
term: Term,
filter_name: str,
trackstate: bool,
filter_action: Optional[str],
af: str = 'inet',
verbose: bool = True,
) -> None:
"""Setup a new term.
Args:
@@ -103,7 +114,7 @@ def __init__(self, term, filter_name, trackstate, filter_action, af='inet', verb

self.term_name = '%s_%s' % (self.filter[:1], self.term.name)

def __str__(self):
def __str__(self) -> str:
ret_str = []

# Don't render icmpv6 protocol terms under inet, or icmp under inet6
@@ -361,7 +372,38 @@ def __str__(self):

return '\n'.join(str(v) for v in ret_str if v)

def _CalculateAddresses(self, term_saddr, exclude_saddr, term_daddr, exclude_daddr):
def _CalculateAddresses(
self,
term_saddr: List[Union[IPv4, IPv6]],
exclude_saddr: List[Union[IPv4, IPv6]],
term_daddr: List[Union[IPv4, IPv6]],
exclude_daddr: List[Union[IPv4, IPv6]],
) -> Union[
Tuple[
List[Union[IPv4, IPv6]],
List[Union[IPv4, IPv6]],
List[Union[IPv4, IPv6]],
List[Union[IPv4, IPv6]],
],
Tuple[
List[Union[IPv4, IPv6]],
List[Union[IPv4, IPv6]],
List[Union[IPv4, IPv6]],
List[Union[IPv4, IPv6]],
],
Tuple[
List[Union[IPv4, IPv6]],
List[Union[IPv4, IPv6]],
List[Union[IPv4, IPv6]],
List[Union[IPv4, IPv6]],
],
Tuple[
List[Union[IPv4, IPv6]],
List[Union[IPv4, IPv6]],
List[Union[IPv4, IPv6]],
List[Union[IPv4, IPv6]],
],
]:
"""Calculate source and destination address list for a term.
Args:
@@ -438,21 +480,21 @@ def _CalculateAddresses(self, term_saddr, exclude_saddr, term_daddr, exclude_dad

def _FormatPart(
self,
protocol,
saddr,
sport,
daddr,
dport,
options,
tcp_flags,
icmp_type,
code,
track_flags,
sint,
dint,
log_hits,
action,
):
protocol: str,
saddr: Union[IPv6, IPv4, str],
sport: Union[List[Tuple[int, int]], str],
daddr: Union[IPv6, IPv4, str],
dport: Union[List[Tuple[int, int]], str],
options: Union[str, List[str]],
tcp_flags: Union[str, List[str]],
icmp_type: Union[int, str],
code: Union[int, str],
track_flags: Union[Tuple[List[str], List[str]], str],
sint: str,
dint: str,
log_hits: Union[str, bool],
action: str,
) -> List[str]:
"""Compose one iteration of the term parts into a string.
Args:
@@ -501,9 +543,6 @@ def _FormatPart(
if protocol and not proto:
proto = '-p %s' % str(protocol)

# TODO(vklimovs): generalize to all v6 special cases
# Use u32 module as named modules are not available
# everywhere.
if protocol == 'hopopt':
proto = ''
# Select 4 bytes at offset 0x3, mask out all but
@@ -598,7 +637,9 @@ def _FormatPart(
ret_lines.append(' '.join(rval + [action]))
return ret_lines

def _GenerateAddressStatement(self, saddr, daddr):
def _GenerateAddressStatement(
self, saddr: Union[IPv6, IPv4], daddr: Union[IPv6, IPv4]
) -> Tuple[str, str]:
"""Return the address section of an individual iptables rule.
Args:
@@ -622,7 +663,9 @@ def _GenerateAddressStatement(self, saddr, daddr):
dst = '-d %s/%d' % (daddr.network_address, daddr.prefixlen)
return (src, dst)

def _GeneratePortStatement(self, ports, source=False, dest=False):
def _GeneratePortStatement(
self, ports: List[Tuple[int, int]], source: bool = False, dest: bool = False
) -> List[str]:
"""Return the 'port' section of an individual iptables rule.
Args:
@@ -676,7 +719,7 @@ def _GeneratePortStatement(self, ports, source=False, dest=False):
portstrings.append('-m multiport --%sports %s' % (direction, ','.join(norm_ports)))
return portstrings

def _SetDefaultAction(self):
def _SetDefaultAction(self) -> None:
"""If term does not specify action, use filter default action."""
if not self.term.action:
self.term.action[0].value = self.default_action
@@ -698,11 +741,11 @@ class Iptables(aclgenerator.ACLGenerator):
_GOOD_FILTERS = ['INPUT', 'OUTPUT', 'FORWARD']
_GOOD_OPTIONS = ['nostate', 'abbreviateterms', 'truncateterms', 'noverbose']

def __init__(self, pol, exp_info):
def __init__(self, pol: Policy, exp_info: int) -> None:
self.iptables_policies = []
super().__init__(pol, exp_info)

def _BuildTokens(self):
def _BuildTokens(self) -> Tuple[Set[str], Dict[str, Set[str]]]:
"""Build supported tokens for platform.
Returns:
@@ -747,7 +790,7 @@ def _BuildTokens(self):
)
return supported_tokens, supported_sub_tokens

def _WarnIfCustomTarget(self, target):
def _WarnIfCustomTarget(self, target: str) -> None:
"""Emit a warning if a policy's default target is not a built-in chain."""
if target not in self._GOOD_FILTERS:
logging.warning(
@@ -757,7 +800,7 @@ def _WarnIfCustomTarget(self, target):
target,
)

def _TranslatePolicy(self, pol, exp_info):
def _TranslatePolicy(self, pol: Policy, exp_info: int) -> None:
"""Translate a policy from objects into strings."""
default_action = None
good_default_actions = ['ACCEPT', 'DROP']
@@ -871,7 +914,7 @@ def _TranslatePolicy(self, pol, exp_info):
(header, filter_name, filter_type, default_action, new_terms)
)

def SetTarget(self, target, action=None):
def SetTarget(self, target: str, action: Optional[str] = None) -> None:
"""Sets policy's target and default action.
Args:
@@ -886,7 +929,7 @@ def SetTarget(self, target, action=None):
pol[3] = action
self.iptables_policies[0] = tuple(pol)

def __str__(self):
def __str__(self) -> str:
target = []
pretty_platform = '%s%s' % (self._PLATFORM[0].upper(), self._PLATFORM[1:])

0 comments on commit 334e88c

Please sign in to comment.