Skip to content

Commit

Permalink
Merge pull request #38 from bliutech/22-solver-implement-sympy-simpli…
Browse files Browse the repository at this point in the history
…fier

22 solver implement sympy simplifier
  • Loading branch information
bliutech authored Jul 19, 2024
2 parents 7b8f948 + ea7a540 commit 56be238
Showing 1 changed file with 102 additions and 0 deletions.
102 changes: 102 additions & 0 deletions solver/passes/sympy_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from typing import override
from parser.ast import (
Var,
Expr,
VarExpr,
NotExpr,
ParenExpr,
AndExpr,
OrExpr,
)
from parser.visitor import Visitor, RetVisitor
from parser.parse import Parser
from parser.lex import Lexer

import sympy
from sympy.logic.boolalg import And, Or, Not


def run_pass(ast: Expr) -> Expr:
# Sympy visitor for mapping variable
v: SympyMappingVisitor = SympyMappingVisitor()
ast.accept(v)

tv: TranslateToSympy = TranslateToSympy(v.symbols)

p: sympy.Basic = ast.acceptRet(tv)

simplifiedExpr: sympy.Basic = sympy.simplify_logic(p)
simplifiedStr: str = str(simplifiedExpr)

l: Lexer = Lexer()
l.lex(simplifiedStr)

pa: Parser = Parser()
ast: Expr = pa.parse(l.tokens)

return ast


class SympyMappingVisitor(Visitor):
"""
A visitor that visits each node in the AST and adds Var nodes to the symbolMap.
"""

def __init__(self) -> None:
self.symbolMap: dict[str, sympy.Symbol] = {}

@override
def visitVar(self, va: Var) -> None:
self.symbolMap[va.name] = sympy.Symbol(va.name)


class TranslateToSympy(RetVisitor[sympy.Basic]):
"""
A visitor that visits each node in the AST and
returns an expression translated to Sympy logic.
"""

def __init__(self, symbols: dict[str, sympy.Symbol]) -> None:
self.symbols = symbols

@override
def visitVarExpr(self, vex: VarExpr) -> sympy.Basic:
first: sympy.Basic = vex.first.acceptRet(self)
if vex.second:
second: sympy.Basic = vex.second.first.acceptRet(self)
if isinstance(vex.second, AndExpr):
return And(first, second)
elif isinstance(vex.second, OrExpr):
return Or(first, second)
return first

@override
def visitNotExpr(self, nex: NotExpr) -> sympy.Basic:
return Not(nex.first.acceptRet(self))

@override
def visitParenExpr(self, pex: ParenExpr) -> sympy.Basic:
return pex.first.acceptRet(self)

@override
def visitAndExpr(self, aex: AndExpr) -> sympy.Basic:
pass

@override
def visitOrExpr(self, oex: OrExpr) -> sympy.Basic:
pass

@override
def visitVar(self, va: Var) -> sympy.Basic:
return self.symbols[va.name]


if __name__ == "__main__":
prog = "B & A | !A"
l: Lexer = Lexer()
l.lex(prog)

p: Parser = Parser()
ast: Expr = p.parse(l.tokens)

run_pass(ast)

0 comments on commit 56be238

Please sign in to comment.