Skip to content

Commit

Permalink
Improved query, print & exception handling in REPL Tool (langchain-ai…
Browse files Browse the repository at this point in the history
…#4997)

Update to pull request langchain-ai#3215

Summary:
1) Improved the sanitization of query (using regex), by removing python
command (since gpt-3.5-turbo sometimes assumes python console as a
terminal, and runs python command first which causes error). Also
sometimes 1 line python codes contain single backticks.
2) Added 7 new test cases.

For more details, view the previous pull request.

---------

Co-authored-by: Deepak S V <[email protected]>
  • Loading branch information
svdeepak99 and svdeepak99 authored May 22, 2023
1 parent 785502e commit 49ca027
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 17 deletions.
34 changes: 17 additions & 17 deletions langchain/tools/python/tool.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""A tool for running python code in a REPL."""

import ast
import re
import sys
from contextlib import redirect_stdout
from io import StringIO
from typing import Any, Dict, Optional

Expand All @@ -19,14 +21,13 @@ def _get_default_python_repl() -> PythonREPL:
return PythonREPL(_globals=globals(), _locals=None)


_MD_PY_BLOCK = "```python"


def sanitize_input(query: str) -> str:
query = query.strip()
if query[: len(_MD_PY_BLOCK)] == _MD_PY_BLOCK:
query = query[len(_MD_PY_BLOCK) :].strip()
query = query.strip("`").strip()
# Remove whitespace, backtick & python (if llm mistakes python console as terminal)

# Removes `, whitespace & python from start
query = re.sub(r"^(\s|`)*(?i:python)?\s*", "", query)
# Removes whitespace & ` from end
query = re.sub(r"(\s|`)*$", "", query)
return query


Expand Down Expand Up @@ -101,19 +102,18 @@ def _run(
exec(ast.unparse(module), self.globals, self.locals) # type: ignore
module_end = ast.Module(tree.body[-1:], type_ignores=[])
module_end_str = ast.unparse(module_end) # type: ignore
io_buffer = StringIO()
try:
return eval(module_end_str, self.globals, self.locals)
with redirect_stdout(io_buffer):
ret = eval(module_end_str, self.globals, self.locals)
if ret is None:
return io_buffer.getvalue()
else:
return ret
except Exception:
old_stdout = sys.stdout
sys.stdout = mystdout = StringIO()
try:
with redirect_stdout(io_buffer):
exec(module_end_str, self.globals, self.locals)
sys.stdout = old_stdout
output = mystdout.getvalue()
except Exception as e:
sys.stdout = old_stdout
output = repr(e)
return output
return io_buffer.getvalue()
except Exception as e:
return "{}: {}".format(type(e).__name__, str(e))

Expand Down
110 changes: 110 additions & 0 deletions tests/unit_tests/tools/python/test_python.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test Python REPL Tools."""
import sys

import numpy as np
import pytest

from langchain.tools.python.tool import (
Expand All @@ -17,6 +18,18 @@ def test_python_repl_tool_single_input() -> None:
assert int(tool.run("print(1 + 1)").strip()) == 2


def test_python_repl_print() -> None:
program = """
import numpy as np
v1 = np.array([1, 2, 3])
v2 = np.array([4, 5, 6])
dot_product = np.dot(v1, v2)
print("The dot product is {:d}.".format(dot_product))
"""
tool = PythonREPLTool()
assert tool.run(program) == "The dot product is 32.\n"


@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
Expand All @@ -27,6 +40,103 @@ def test_python_ast_repl_tool_single_input() -> None:
assert tool.run("1 + 1") == 2


@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_return() -> None:
program = """
```
import numpy as np
v1 = np.array([1, 2, 3])
v2 = np.array([4, 5, 6])
dot_product = np.dot(v1, v2)
int(dot_product)
```
"""
tool = PythonAstREPLTool()
assert tool.run(program) == 32

program = """
```python
import numpy as np
v1 = np.array([1, 2, 3])
v2 = np.array([4, 5, 6])
dot_product = np.dot(v1, v2)
int(dot_product)
```
"""
assert tool.run(program) == 32


@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_print() -> None:
program = """python
string = "racecar"
if string == string[::-1]:
print(string, "is a palindrome")
else:
print(string, "is not a palindrome")"""
tool = PythonAstREPLTool()
assert tool.run(program) == "racecar is a palindrome\n"


@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_repl_print_python_backticks() -> None:
program = "`print('`python` is a great language.')`"
tool = PythonAstREPLTool()
assert tool.run(program) == "`python` is a great language.\n"


@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_raise_exception() -> None:
data = {"Name": ["John", "Alice"], "Age": [30, 25]}
program = """
import pandas as pd
df = pd.DataFrame(data)
df['Gender']
"""
tool = PythonAstREPLTool(locals={"data": data})
expected_outputs = (
"KeyError: 'Gender'",
"ModuleNotFoundError: No module named 'pandas'",
)
assert tool.run(program) in expected_outputs


@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_one_line_print() -> None:
program = 'print("The square of {} is {:.2f}".format(3, 3**2))'
tool = PythonAstREPLTool()
assert tool.run(program) == "The square of 3 is 9.00\n"


@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_one_line_return() -> None:
arr = np.array([1, 2, 3, 4, 5])
tool = PythonAstREPLTool(locals={"arr": arr})
program = "`(arr**2).sum() # Returns sum of squares`"
assert tool.run(program) == 55


@pytest.mark.skipif(
sys.version_info < (3, 9), reason="Requires python version >= 3.9 to run."
)
def test_python_ast_repl_one_line_exception() -> None:
program = "[1, 2, 3][4]"
tool = PythonAstREPLTool()
assert tool.run(program) == "IndexError: list index out of range"


def test_sanitize_input() -> None:
query = """
```
Expand Down

0 comments on commit 49ca027

Please sign in to comment.