Skip to content

Commit

Permalink
Add a regex solution to remove docstrings for coding.correction.run_with_correction().
Browse files Browse the repository at this point in the history
The motivation is that docstrings in the source field cannot be parsed correctly by the current code-parsing implementation.

Also add `autofix` arguments to the `__call__()` and `eval()` methods of lf.PythonCode.

PiperOrigin-RevId: 597969045
  • Loading branch information
yifenglou authored and langfun authors committed Jan 12, 2024
1 parent c966215 commit 64b027b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
8 changes: 7 additions & 1 deletion langfun/core/coding/python/correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Python code error correction."""

import re
from typing import Any
import langfun.core as lf
from langfun.core.coding.python import errors
Expand All @@ -31,6 +31,11 @@ class CorrectedCode(pg.Object):
corrected_code: str


# Compiled once at import time. Under re.DOTALL, `.*?` matches exactly the same
# text as the former `(.|\s)*?`, but without the overlapping alternation: since
# `.` already matches whitespace, every whitespace character could be consumed
# two ways, which made an unterminated triple quote backtrack exponentially.
_DOCSTRING_AFTER_DEF = re.compile(
    r"(def .+?:\s*?)('''|\"\"\").*?\2", re.DOTALL
)


def remove_docstrings(code: str) -> str:
  """Strips triple-quoted strings that follow a `def ...:` header.

  The match is purely textual (no parsing), so it also works on source that
  is not valid Python. For every match, the `def` header and the whitespace
  after its colon are kept, and the triple-quoted string (with both
  delimiters) is dropped.

  Limitations that follow from the pattern:
    * Only text after a `def ` is considered; module and class docstrings
      are left untouched.
    * Nothing is inserted in place of the removed string, so a function whose
      body consisted solely of a docstring is left without a body.
    * Because `.` spans newlines, any triple-quoted string that directly
      follows a colon (plus optional whitespace) later in the text after a
      `def ` also matches, e.g. a triple-quoted dict value.

  Args:
    code: Source text to process.

  Returns:
    `code` with the matched triple-quoted strings removed. Text without a
    match is returned unchanged.
  """
  # The back-reference `\2` forces the closing delimiter to be the same kind
  # of triple quote as the opening one.
  return _DOCSTRING_AFTER_DEF.sub(r"\1", code)


def run_with_correction(
code: str,
error: str | None = None,
Expand Down Expand Up @@ -81,6 +86,7 @@ def run_with_correction(
# pytype: disable=import-error
# pylint: enable=g-import-not-at-top

code = remove_docstrings(code)
if max_attempts == 0:
result = execution.run(
code,
Expand Down
8 changes: 8 additions & 0 deletions langfun/core/coding/python/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __call__(
sandbox: bool | None = None,
timeout: int | None = 5,
global_vars: dict[str, Any] | None = None,
autofix: int = 3,
autofix_lm: lf.LanguageModel = lf.contextual(),
) -> Any:
"""Returns the value of the last expression from the source.
Expand All @@ -100,6 +101,8 @@ def __call__(
timeout: Timeout in seconds. If None, there is no timeout. Applicable when
sandbox is set to True.
global_vars: Global variables that could be accessed from the source code.
autofix: Number of attempts to auto fix the generated code. If 0, autofix
is disabled.
autofix_lm: Language model to be used. If not specified, it will try to
use the `lm` under `lf.context`.
Expand All @@ -115,6 +118,7 @@ def __call__(
global_vars=global_vars,
sandbox=sandbox,
timeout=timeout,
max_attempts=autofix,
lm=autofix_lm,
returns_code=True,
)
Expand All @@ -127,6 +131,7 @@ def eval(
sandbox: bool | None = None,
timeout: int | None = 5,
global_vars: dict[str, Any] | None = None,
autofix: int = 3,
autofix_lm: lf.LanguageModel = lf.contextual(),
) -> Any | tuple[Any, str]:
"""Evaluates the code and return a dict of local variable names to values.
Expand All @@ -139,6 +144,8 @@ def eval(
timeout: Timeout in seconds. If None, there is no timeout. Applicable when
sandbox is set to True.
global_vars: Global variables that could be accessed from the source code.
autofix: Number of attempts to auto fix the generated code. If 0, autofix
is disabled. Auto-fix is not supported for 'json' protocol.
autofix_lm: Language model to be used. If not specified, it will try to
use the `lm` under `lf.context`.
Expand All @@ -157,6 +164,7 @@ def eval(
sandbox=sandbox,
timeout=timeout,
outputs_intermediate=True,
max_attempts=autofix,
lm=autofix_lm,
returns_code=True,
)
Expand Down

0 comments on commit 64b027b

Please sign in to comment.