Skip to content

Commit

Permalink
modify allowed modules setting (microsoft#182)
Browse files Browse the repository at this point in the history
  • Loading branch information
ShilinHe authored Jan 30, 2024
1 parent e299e75 commit 87ce39f
Showing 1 changed file with 26 additions and 18 deletions.
44 changes: 26 additions & 18 deletions taskweaver/code_interpreter/code_verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,29 +42,37 @@ def visit_Call(self, node):
return True

def visit_Import(self, node):
if len(self.allowed_modules) > 0:
for alias in node.names:
if "." in alias.name:
module_name = alias.name.split(".")[0]
else:
module_name = alias.name
if len(self.allowed_modules) > 0 and module_name not in self.allowed_modules:
self.errors.append(
f"Error on line {node.lineno}: {self.lines[node.lineno-1]} "
f"=> Importing module '{module_name}' is not allowed. ",
)

def visit_ImportFrom(self, node):
if len(self.allowed_modules) > 0:
if "." in node.module:
module_name = node.module.split(".")[0]
for alias in node.names:
if "." in alias.name:
module_name = alias.name.split(".")[0]
else:
module_name = node.module
module_name = alias.name
if len(self.allowed_modules) > 0 and module_name not in self.allowed_modules:
self.errors.append(
f"Error on line {node.lineno}: {self.lines[node.lineno-1]} "
f"=> Importing from module '{node.module}' is not allowed.",
f"=> Importing module '{module_name}' is not allowed. ",
)
elif len(self.allowed_modules) == 0:
self.errors.append(
f"Error on line {node.lineno}: {self.lines[node.lineno-1]} "
f"=> Importing module '{module_name}' is not allowed. ",
)

def visit_ImportFrom(self, node):
if "." in node.module:
module_name = node.module.split(".")[0]
else:
module_name = node.module
if len(self.allowed_modules) > 0 and module_name not in self.allowed_modules:
self.errors.append(
f"Error on line {node.lineno}: {self.lines[node.lineno-1]} "
f"=> Importing from module '{node.module}' is not allowed.",
)
elif len(self.allowed_modules) == 0:
self.errors.append(
f"Error on line {node.lineno}: {self.lines[node.lineno-1]} "
f"=> Importing from module '{node.module}' is not allowed.",
)

def generic_visit(self, node):
super().generic_visit(node)
Expand Down

0 comments on commit 87ce39f

Please sign in to comment.