Skip to content

Commit

Permalink
modified tree.py, setup.py
Browse files Browse the repository at this point in the history
  • Loading branch information
drinder committed Apr 2, 2023
1 parent 5c30481 commit 48d6aad
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
Binary file modified codes/__pycache__/setup.cpython-39.pyc
Binary file not shown.
4 changes: 2 additions & 2 deletions codes/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ def fun(x,t,Net):
OP2 = np.array([['+', 2, np.add], ['-', 2, np.subtract], ['*', 2, np.multiply], ['/', 2, divide], ['d', 2, Diff], ['d^2', 2, Diff2]])
# VARS = np.array([['u', 0, u], ['x', 0, x], ['0', 0, zeros], ['ux', 0, ux], ['uxx', 0, uxx], ['u^2', 0, u**2]])
VARS = np.array([['u', 0, u], ['x', 0, x], ['0', 0, zeros], ['ux', 0, ux]])
den = np.array([['x', 0, x]])
DENOMINATOR = np.array([['x', 0, x]])

# else:
# ALL = np.array([['sin', 1, np.sin], ['cos', 1, np.cos], ['log', 1, np.log], ['+', 2, np.add], ['-', 2, np.subtract],
Expand All @@ -365,6 +365,6 @@ def fun(x,t,Net):
# OP2 = np.array(
# [['+', 2, np.add], ['-', 2, np.subtract], ['*', 2, np.multiply], ['/', 2, np.divide], ['d', 2, Diff]])
# VARS = np.array([['u', 0, u], ['t', 0, t], ['x', 0, x]])
# den = np.array([['t', 0, t], ['x', 0, x]])
# DENOMINATOR = np.array([['t', 0, t], ['x', 0, x]])

pde_lib, err_lib = [], []
8 changes: 4 additions & 4 deletions codes/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def __init__(self, max_depth, p_var):
continue
for j in range(parent.child_num):
# rule 1
if parent.name in {'d', 'd^2'} and j == 1: # j == 0 为d的左侧节点,j == 1为d的右侧节点
node = den[np.random.randint(0, len(den))] # 随机产生一个微分运算的denominator,一般是xyt
if parent.name in {'d', 'd^2'} and j == 1:
node = DENOMINATOR[np.random.randint(0, len(DENOMINATOR))]
node = Node(depth=depth, idx=len(self.tree[depth]), parent_idx=parent_idx, name=node[0],
var=node[2], full=node, child_num=int(node[1]), child_st=None)
self.tree[depth].append(node)
Expand Down Expand Up @@ -126,7 +126,7 @@ def mutate(self, p_mute): #直接替换原有tree中的某个节点,用同类
# print('mutate!')
if num_child == 0: # 叶子节点
node = VARS[np.random.randint(0, len(VARS))] # rule 2: 叶节点必须是var,不能是op
while node[0] == temp or (parent.name in {'d', 'd^2'} and node[0] not in den[:, 0]):# rule 3: 如果编译前后结果重复,或者d的节点不在den中(即出现不能求导的对象),则重新抽取
while node[0] == temp or (parent.name in {'d', 'd^2'} and node[0] not in DENOMINATOR[:, 0]):# rule 3: 如果编译前后结果重复,或者d的节点不在den中(即出现不能求导的对象),则重新抽取
if simple_mode and parent.name in {'d', 'd^2'} and node[0] == 'x': # simple_mode中,遇到对于x的导数,直接停止变异
break
node = VARS[np.random.randint(0, len(VARS))] # 重新抽取一个vars
Expand All @@ -141,7 +141,7 @@ def mutate(self, p_mute): #直接替换原有tree中的某个节点,用同类
elif num_child == 2:
node = OP2[np.random.randint(0, len(OP2))]
right = self.tree[depth + 1][current.child_st + 1].name
while node[0] == temp or (node[0] in {'d', 'd^2'} and right not in den[:, 0]):# rule 4: 避免重复,避免生成d以打乱树结构(新d的右子节点不是x)
while node[0] == temp or (node[0] in {'d', 'd^2'} and right not in DENOMINATOR[:, 0]):# rule 4: 避免重复,避免生成d以打乱树结构(新d的右子节点不是x)
node = OP2[np.random.randint(0, len(OP2))]
else:
raise NotImplementedError("Error occurs!")
Expand Down

0 comments on commit 48d6aad

Please sign in to comment.