Skip to content

Commit

Permalink
Merge pull request facebookresearch#682 from ChrisCummins/feature/und…
Browse files Browse the repository at this point in the history
…o-wrapper

[wrappers] Add a ForkOnStep wrapper.
  • Loading branch information
ChrisCummins authored May 17, 2022
2 parents 51ff9d1 + 366f218 commit 41487a6
Showing 9 changed files with 198 additions and 0 deletions.
1 change: 1 addition & 0 deletions compiler_gym/wrappers/BUILD
Original file line number Diff line number Diff line change
@@ -11,6 +11,7 @@ py_library(
"commandline.py",
"core.py",
"datasets.py",
"fork.py",
"llvm.py",
"sqlite_logger.py",
"time_limit.py",
1 change: 1 addition & 0 deletions compiler_gym/wrappers/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@ set(WRAPPERS_SRCS
"commandline.py"
"core.py"
"datasets.py"
"fork.py"
"time_limit.py"
"validation.py"
)
2 changes: 2 additions & 0 deletions compiler_gym/wrappers/__init__.py
Original file line number Diff line number Diff line change
@@ -44,6 +44,7 @@
IterateOverBenchmarks,
RandomOrderBenchmarks,
)
from compiler_gym.wrappers.fork import ForkOnStep

if config.enable_llvm_env:
from compiler_gym.wrappers.llvm import RuntimePointEstimateReward # noqa: F401
@@ -62,6 +63,7 @@
"ConstrainedCommandline",
"CycleOverBenchmarks",
"CycleOverBenchmarksIterator",
"ForkOnStep",
"IterateOverBenchmarks",
"ObservationWrapper",
"RandomOrderBenchmarks",
75 changes: 75 additions & 0 deletions compiler_gym/wrappers/fork.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""This module implements fork wrappers."""
from typing import List

from compiler_gym.envs import CompilerEnv
from compiler_gym.wrappers import CompilerEnvWrapper


class ForkOnStep(CompilerEnvWrapper):
"""A wrapper that creates a fork of the environment before every step.
This wrapper creates a new fork of the environment before every call to
:meth:`env.reset() <compiler_gym.envs.CompilerEnv.reset>`. Because of this,
this environment supports an additional :meth:`env.undo()
<compiler_gym.wrappers.ForkOnStep.undo>` method that can be used to
backtrack.
Example usage:
>>> env = ForkOnStep(compiler_gym.make("llvm-v0"))
>>> env.step(0)
>>> env.actions
[0]
>>> env.undo()
>>> env.actions
[]
:ivar stack: A fork of the environment before every previous call to
:meth:`env.reset() <compiler_gym.envs.CompilerEnv.reset>`, ordered
oldest to newest.
:vartype stack: List[CompilerEnv]
"""

def __init__(self, env: CompilerEnv):
"""Constructor.
:param env: The environment to wrap.
"""
super().__init__(env)
self.stack: List[CompilerEnv] = []

def undo(self) -> CompilerEnv:
"""Undo the previous action.
:returns: Self.
"""
if not self.stack:
return
self.env.close()
self.env = self.stack.pop()
return self.env

def close(self) -> None:
for env in self.stack:
env.close()
self.stack: List[CompilerEnv] = []
self.env.close()
self.custom_close = True

def reset(self, *args, **kwargs):
self.env.reset()
for env in self.stack:
env.close()
self.stack: List[CompilerEnv] = []

def step(self, *args, **kwargs):
self.stack.append(self.env.fork())
return self.env.step(*args, **kwargs)

def fork(self):
raise NotImplementedError
9 changes: 9 additions & 0 deletions docs/source/compiler_gym/wrappers.rst
Original file line number Diff line number Diff line change
@@ -49,6 +49,15 @@ Action space wrappers

.. autoclass:: TimeLimit

.. automethod:: __init__


.. autoclass:: ForkOnStep

.. automethod:: __init__

.. automethod:: undo


Datasets wrappers
-----------------
12 changes: 12 additions & 0 deletions tests/wrappers/BUILD
Original file line number Diff line number Diff line change
@@ -37,6 +37,18 @@ py_test(
],
)

py_test(
name = "fork_test",
srcs = ["fork_test.py"],
deps = [
"//compiler_gym/envs/llvm",
"//compiler_gym/errors",
"//compiler_gym/wrappers",
"//tests:test_main",
"//tests/pytest_plugins:llvm",
],
)

py_test(
name = "llvm_test",
timeout = "long",
11 changes: 11 additions & 0 deletions tests/wrappers/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -42,6 +42,17 @@ cg_py_test(
tests::test_main
)

cg_py_test(
NAME fork_test
SRCS "fork_test.py"
DEPS
compiler_gym::envs::llvm::llvm
compiler_gym::errors::errors
compiler_gym::wrappers::wrappers
tests::test_main
tests::pytest_plugins::llvm
)

cg_py_test(
NAME llvm_test
SRCS "llvm_test.py"
19 changes: 19 additions & 0 deletions tests/wrappers/core_wrappers_test.py
Original file line number Diff line number Diff line change
@@ -311,5 +311,24 @@ def test_wrapped_env_close(env: LlvmEnv):
assert wrapped.service is None


def test_wrapped_env_custom_close(env: LlvmEnv):
"""Test that a custom close() method is called on wrapped environments."""

class MyWrapper(CompilerEnvWrapper):
def __init__(self, env: LlvmEnv):
super().__init__(env)
self.custom_close = False

def close(self):
self.custom_close = True
self.env.close()

env = MyWrapper(env)
assert not env.custom_close

env.close()
assert env.custom_close


if __name__ == "__main__":
main()
68 changes: 68 additions & 0 deletions tests/wrappers/fork_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Unit tests for //compiler_gym/wrappers."""
from compiler_gym.envs.llvm import LlvmEnv
from compiler_gym.wrappers import ForkOnStep
from tests.test_main import main

pytest_plugins = ["tests.pytest_plugins.llvm"]


def test_ForkOnStep_step(env: LlvmEnv):
with ForkOnStep(env) as env:
env.reset()
assert env.stack == []

env.step(0)
assert env.actions == [0]
assert len(env.stack) == 1
assert env.stack[0].actions == []

env.step(1)
assert env.actions == [0, 1]
assert len(env.stack) == 2
assert env.stack[1].actions == [0]
assert env.stack[0].actions == []


def test_ForkOnStep_reset(env: LlvmEnv):
with ForkOnStep(env) as env:
env.reset()

env.step(0)
assert env.actions == [0]
assert len(env.stack) == 1

env.reset()
assert env.actions == []
assert env.stack == []


def test_ForkOnStep_double_close(env: LlvmEnv):
with ForkOnStep(env) as env:
env.close()
env.close()


def test_ForkOnStep_undo(env: LlvmEnv):
with ForkOnStep(env) as env:
env.reset()

env.step(0)
assert env.actions == [0]
assert len(env.stack) == 1

env.undo()
assert env.actions == []
assert not env.stack

# Undo of an empty stack:
env.undo()
assert env.actions == []
assert not env.stack


if __name__ == "__main__":
main()

0 comments on commit 41487a6

Please sign in to comment.