Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Utils create a dictionary decoder #35

Merged
merged 9 commits into from
Jul 19, 2024
77 changes: 62 additions & 15 deletions tests/test_coding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from unittest import TestCase
from utils.coding import DictionaryEncoder
from utils.coding import DictionaryEncoder, DictionaryDecoder

# TODO: rewrite this with subtests and using a bidirectional hash map
# to represent the tabled test cases since they are repetitive.


class TestDictionaryEncoder(TestCase):
Expand All @@ -12,21 +15,65 @@ class TestDictionaryEncoder(TestCase):
Tests the encode method.
"""

test_cases: dict[str, str] = {
"if (([ebp_1 + 0x14].b == 0 || [ebp_1 + 0x14].b != 0 && [ebp_1 + 0x14].c != 0) || ![ebp_1 + 0x14] == 0) then 387 @ 0x8040da8 else 388 @ 0x8040d8b": "( ( A | ! A & B ) | ! C )",
"if ([ebp_1 + 0x14].b == 0 || [ebp_1 + 0x14].b != 0) then 387 @ 0x8040da8 else 388 @ 0x8040d8b": "( A | ! A )",
"while (x == 3 && y >= 2):": "( A & B )",
"if ([ebp_1 + 0x14].b == 0 || [ebp_1 + 0x14].b != 0 && [ebp_1 + 0x14].c != 0) || (![ebp_1 + 0x14] == 0) then 387 @ 0x8040da8 else 388 @ 0x8040d8b": "( A | ! A & B ) | ( ! C )",
}

def test_encode(self) -> None:
"""
Tests the encode method.
"""
for expected, actual in self.test_cases.items():
with self.subTest(expected=expected, actual=actual):
encoder: DictionaryEncoder = DictionaryEncoder()
encoded_str: str = encoder.encode(expected)
self.assertEqual(
encoded_str, actual, "Two values are not equal to each other..."
)
test_cases_encoded: dict[str, str] = {
"if (([ebp_1 + 0x14].b == 0 || [ebp_1 + 0x14].b != 0 && [ebp_1 + 0x14].c != 0) || ![ebp_1 + 0x14] == 0) then 387 @ 0x8040da8 else 388 @ 0x8040d8b": "( ( A | ! A & ! B ) | ! C )",
"if ([ebp_1 + 0x14].b == 0 || [ebp_1 + 0x14].b != 0) then 387 @ 0x8040da8 else 388 @ 0x8040d8b": "( A | ! A )",
"while (x == 3 && y >= 2):": "( A & B )",
"if ([ebp_1 + 0x14].b == 0 || [ebp_1 + 0x14].b != 0 && [ebp_1 + 0x14].c != 0) || (![ebp_1 + 0x14] == 0) then 387 @ 0x8040da8 else 388 @ 0x8040d8b": "( A | ! A & ! B ) | ( ! C )",
}
for encoded_test in test_cases_encoded:
encoder: DictionaryEncoder = DictionaryEncoder()
encoded_str: str = encoder.encode(encoded_test)
answer: str = test_cases_encoded.get(encoded_test)
self.assertEqual(
encoded_str, answer, "Two values are not equal to each other..."
)


class TestDictionaryDecoder(TestCase):
"""
A class to test the DictionaryDecoder class.

Methods
-------
test_decode() -> None:
Tests the decode method.
"""

def test_decode(self) -> None:
"""
Tests the encode method.
"""
test_cases_encoded: dict[str, str] = {
"if (([ebp_1 + 0x14].b == 0 || [ebp_1 + 0x14].b != 0 && [ebp_1 + 0x14].c != 0) || ![ebp_1 + 0x14] == 0) then 387 @ 0x8040da8 else 388 @ 0x8040d8b": "( ( A | ! A & ! B ) | ! C )",
"if ([ebp_1 + 0x14].b == 0 || [ebp_1 + 0x14].b != 0) then 387 @ 0x8040da8 else 388 @ 0x8040d8b": "( A | ! A )",
"while (x == 3 && y >= 2):": "( A & B )",
"if ([ebp_1 + 0x14].b == 0 || [ebp_1 + 0x14].b != 0 && [ebp_1 + 0x14].c != 0) || (![ebp_1 + 0x14] == 0) then 387 @ 0x8040da8 else 388 @ 0x8040d8b": "( A | ! A & ! B ) | ( ! C )",
}

test_cases_decoded: dict[str, str] = {
"( ( A | ! A & ! B ) | ! C )": "( ( [ebp_1 + 0x14].b == 0 || [ebp_1 + 0x14].b != 0 && [ebp_1 + 0x14].c != 0 ) || [ebp_1 + 0x14] != 0 )",
"( A | ! A )": "( [ebp_1 + 0x14].b == 0 || [ebp_1 + 0x14].b != 0 )",
"( A & B )": "( x == 3 && y >= 2 )",
"( A | ! A & ! B ) | ( ! C )": "( [ebp_1 + 0x14].b == 0 || [ebp_1 + 0x14].b != 0 && [ebp_1 + 0x14].c != 0 ) || ( [ebp_1 + 0x14] != 0 )",
}

for encoded_test in test_cases_encoded:
encoder: DictionaryEncoder = DictionaryEncoder()
encoded_str: str = encoder.encode(encoded_test)
answer: str = test_cases_encoded.get(encoded_test)

decoder: DictionaryDecoder = DictionaryDecoder(
encoder.get_encoded_dictionary()
)
decoded_str: str = decoder.decode(answer)
# print(decoded_str)
answer: str = test_cases_decoded.get(answer)
# print(answer)
self.assertEqual(
decoded_str, answer, "Two values are not equal to each other..."
)
93 changes: 71 additions & 22 deletions utils/coding.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,9 @@
import string
from dataclasses import dataclass
import re
from typing import Iterator
import itertools


@dataclass
class Boolean:
"""
A class to represent a boolean expression with raw and encoded forms.

Attributes
----------
raw : str
The raw boolean expression.
encoded : str
The encoded boolean expression.
"""

raw: str
encoded: str


class NameGenerator:
"""
A class to generate unique names for boolean conditions.
Expand All @@ -46,7 +28,7 @@ def __init__(self) -> None:
Constructs all the necessary attributes for the NameGenerator object.
"""
self.generated_dictionary_keys: dict[str, str] = {}
self.prev_state: set[str] = set()
self.prev_state: list[str] = []

def generate_name(self, conditional: str) -> str:
"""
Expand All @@ -62,11 +44,13 @@ def generate_name(self, conditional: str) -> str:
str
The generated name.
"""
replaced_conditional: str = re.sub("!=", "==", conditional)
replaced_conditional = re.sub("!=", "==", conditional)
val: str | None = self.generated_dictionary_keys.get(replaced_conditional)
if val is None:
gen_key: str = next(self.generate_unique_uppercase_string())
self.generated_dictionary_keys[replaced_conditional] = gen_key
if conditional != replaced_conditional:
return "! " + gen_key
return gen_key
else:
if conditional != replaced_conditional:
Expand All @@ -87,9 +71,12 @@ def generate_unique_uppercase_string(self) -> Iterator[str]:
for length in itertools.count(1):
for s in itertools.product(string.ascii_uppercase, repeat=length):
if "".join(s) not in self.prev_state:
self.prev_state.add("".join(s))
self.prev_state.append("".join(s))
yield "".join(s)

def return_encoded_value(self) -> dict[str, str]:
return self.generated_dictionary_keys


class DictionaryEncoder:
"""
Expand All @@ -110,7 +97,7 @@ def __init__(self) -> None:
"""
Constructs all the necessary attributes for the DictionaryEncoder object.
"""
self.name_generator: NameGenerator = NameGenerator()
self.name_generator = NameGenerator()

def encode(self, mlil_if_string: str) -> str:
"""
Expand Down Expand Up @@ -151,3 +138,65 @@ def encode(self, mlil_if_string: str) -> str:
encoded_parts.append(code)

return " ".join(encoded_parts)

def get_encoded_dictionary(self) -> dict[str, str]:
return self.name_generator.return_encoded_value()


class DictionaryDecoder:
def __init__(self, generated_dictionary_keys):
"""
Initialize the DictionaryDecoder with a given mapping.

Parameters
----------
mapping : dict
A dictionary mapping encoded values to their original MLIL values.
"""
self.mapping: dict[str, str] = generated_dictionary_keys

def decode(self, encoded_str):
"""
Decode an encoded boolean statement back to the original values.

Parameters
----------
encoded_str : str
The encoded string.

Returns
-------
str
The decoded string.
"""

LOGICAL_OPERATORS_DECODER: str = r"(\w+|\|\||&&|[!()&|])"
tokens: list[str] = re.split(LOGICAL_OPERATORS_DECODER, encoded_str)
tokens = [cond.strip() for cond in tokens if cond.strip()]
decoded_parts: list[str] = []
i: int = 0
while i < len(tokens):
if tokens[i] in {"|", "&", "(", ")"}:
if tokens[i] == "|":
decoded_parts.append("||")
elif tokens[i] == "&":
decoded_parts.append("&&")
else:
decoded_parts.append(tokens[i])
elif tokens[i] == "!":
i += 1
# TODO: Replace this code with a bidirectional hashmap
replace_not_equals: str = list(self.mapping.keys())[
list(self.mapping.values()).index(tokens[i])
]
replace_not_equals = re.sub("==", "!=", replace_not_equals)
decoded_parts.append(replace_not_equals)
else:
decoded_parts.append(
list(self.mapping.keys())[
Nytro1O1 marked this conversation as resolved.
Show resolved Hide resolved
list(self.mapping.values()).index(tokens[i])
]
)
i += 1

return " ".join(decoded_parts)