[wrappers] Add synchronous SQLite database wrapper.
commit 483f640, 1 parent 7d83233
Showing 6 changed files with 298 additions and 0 deletions.
@@ -0,0 +1,245 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""This module implements a wrapper that logs state transitions to an sqlite
database.
"""
import logging
import pickle
import sqlite3
import zlib
from pathlib import Path
from time import time
from typing import Iterable, Optional, Union

import numpy as np

from compiler_gym.envs import LlvmEnv
from compiler_gym.spaces import Reward
from compiler_gym.util.gym_type_hints import ActionType
from compiler_gym.util.timer import Timer, humanize_duration
from compiler_gym.views import ObservationSpaceSpec
from compiler_gym.wrappers import CompilerEnvWrapper

DB_CREATION_SCRIPT = """
CREATE TABLE IF NOT EXISTS States (
    benchmark_uri TEXT NOT NULL,      -- The URI of the benchmark.
    done INTEGER NOT NULL,            -- 0 = False, 1 = True.
    ir_instruction_count_oz_reward REAL NULLABLE,
    state_id TEXT NOT NULL,           -- 40-char sha1.
    actions TEXT NOT NULL,            -- Decode: [int(x) for x in field.split()]
    PRIMARY KEY (benchmark_uri, actions),
    FOREIGN KEY (state_id) REFERENCES Observations(state_id) ON UPDATE CASCADE
);
CREATE TABLE IF NOT EXISTS Observations (
    state_id TEXT NOT NULL,           -- 40-char sha1.
    ir_instruction_count INTEGER NOT NULL,
    compressed_llvm_ir BLOB NOT NULL,          -- Decode: zlib.decompress(...)
    pickled_compressed_programl BLOB NOT NULL, -- Decode: pickle.loads(zlib.decompress(...))
    autophase TEXT NOT NULL,          -- Decode: np.array([int(x) for x in field.split()], dtype=np.int64)
    instcount TEXT NOT NULL,          -- Decode: np.array([int(x) for x in field.split()], dtype=np.int64)
    PRIMARY KEY (state_id)
);
"""


class SynchronousSqliteLogger(CompilerEnvWrapper):
    """A wrapper for an LLVM environment that logs all transitions to an sqlite
    database.

    Wrap an existing LLVM environment and then use it as per normal:

        >>> env = SynchronousSqliteLogger(
        ...     env=gym.make("llvm-autophase-ic-v0"),
        ...     db_path="example.db",
        ... )

    Connect to the database file you specified:

    .. code-block::

        $ sqlite3 example.db

    There are two tables:

    1. States: records every unique combination of benchmark + actions. For each
       entry, records an identifying state ID, the episode reward, and whether
       the episode is terminated:

    .. code-block::

        sqlite> .mode markdown
        sqlite> .headers on
        sqlite> select * from States limit 5;

        | benchmark_uri            | done | ir_instruction_count_oz_reward | state_id                                 | actions        |
        |--------------------------|------|--------------------------------|------------------------------------------|----------------|
        | generator://csmith-v0/99 | 0    | 0.0                            | d625b874e58f6d357b816e21871297ac5c001cf0 |                |
        | generator://csmith-v0/99 | 0    | 0.0                            | d625b874e58f6d357b816e21871297ac5c001cf0 | 31             |
        | generator://csmith-v0/99 | 0    | 0.0                            | 52f7142ef606d8b1dec2ff3371c7452c8d7b81ea | 31 116         |
        | generator://csmith-v0/99 | 0    | 0.268005818128586              | d8c05bd41b7a6c6157b6a8f0f5093907c7cc7ecf | 31 116 103     |
        | generator://csmith-v0/99 | 0    | 0.288621664047241              | c4d7ecd3807793a0d8bc281104c7f5a8aa4670f9 | 31 116 103 109 |

    2. Observations: records pickled, compressed, and text observation values
       for each unique state.
    """

    def __init__(
        self,
        env: LlvmEnv,
        db_path: Path,
        commit_frequency_in_seconds: int = 300,
        max_step_buffer_length: int = 5000,
    ):
        super().__init__(env)
        if not hasattr(env, "unwrapped"):
            raise TypeError("Requires LlvmEnv base environment")
        if not isinstance(self.unwrapped, LlvmEnv):
            raise TypeError("Requires LlvmEnv base environment")
        db_path.parent.mkdir(exist_ok=True, parents=True)
        self.connection = sqlite3.connect(db_path)
        self.cursor = self.connection.cursor()
        self.commit_frequency = commit_frequency_in_seconds
        self.max_step_buffer_length = max_step_buffer_length

        self.cursor.executescript(DB_CREATION_SCRIPT)
        self.connection.commit()
        self.last_commit = time()

        self.observations_buffer = {}
        self.step_buffer = []

        # Housekeeping notice: Keep these lists in sync with _record().
        self._observations = [
            self.env.observation.spaces["IrSha1"],
            self.env.observation.spaces["Ir"],
            self.env.observation.spaces["Programl"],
            self.env.observation.spaces["Autophase"],
            self.env.observation.spaces["InstCount"],
            self.env.observation.spaces["IrInstructionCount"],
        ]
        self._rewards = [
            self.env.reward.spaces["IrInstructionCountOz"],
            self.env.reward.spaces["IrInstructionCount"],
        ]
        self._reward_totals = np.zeros(len(self._rewards))

    def flush(self) -> None:
        """Flush the buffered steps and observations to database."""
        n_steps, n_observations = len(self.step_buffer), len(self.observations_buffer)

        # Nothing to flush.
        if not n_steps:
            return

        with Timer() as flush_time:
            # Housekeeping notice: Keep these statements in sync with _record().
            self.cursor.executemany(
                "INSERT OR IGNORE INTO States VALUES (?, ?, ?, ?, ?)",
                self.step_buffer,
            )
            self.cursor.executemany(
                "INSERT OR IGNORE INTO Observations VALUES (?, ?, ?, ?, ?, ?)",
                ((k, *v) for k, v in self.observations_buffer.items()),
            )
            self.step_buffer = []
            self.observations_buffer = {}

            self.connection.commit()

        logging.info(
            "Wrote %d state records and %d observations in %s. Last flush %s ago",
            n_steps,
            n_observations,
            flush_time,
            humanize_duration(time() - self.last_commit),
        )
        self.last_commit = time()

    def reset(self, *args, **kwargs):
        observation = self.env.reset(*args, **kwargs)
        observations, rewards, done, info = self.env.multistep(
            actions=[],
            observation_spaces=self._observations,
            reward_spaces=self._rewards,
        )
        assert not done, f"reset() failed! {info}"
        self._reward_totals = np.array(rewards, dtype=np.float32)
        rewards = self._reward_totals
        self._record(
            actions=self.actions,
            observations=observations,
            rewards=self._reward_totals,
            done=False,
        )
        return observation

    def step(
        self,
        action: ActionType,
        observation_spaces: Optional[Iterable[Union[str, ObservationSpaceSpec]]] = None,
        reward_spaces: Optional[Iterable[Union[str, Reward]]] = None,
        observations: Optional[Iterable[Union[str, ObservationSpaceSpec]]] = None,
        rewards: Optional[Iterable[Union[str, Reward]]] = None,
    ):
        assert self.observation_space, "No observation space set"
        assert self.reward_space, "No reward space set"
        assert (
            observation_spaces is None
        ), "SynchronousSqliteLogger does not support observation_spaces"
        assert (
            reward_spaces is None
        ), "SynchronousSqliteLogger does not support reward_spaces"
        assert (
            observations is None
        ), "SynchronousSqliteLogger does not support observations"
        assert rewards is None, "SynchronousSqliteLogger does not support rewards"

        observations, rewards, done, info = self.env.step(
            action=action,
            observation_spaces=self._observations + [self.observation_space_spec],
            reward_spaces=self._rewards + [self.reward_space],
        )
        self._reward_totals += rewards[:-1]
        self._record(
            actions=self.actions,
            observations=observations[:-1],
            rewards=self._reward_totals,
            done=done,
        )
        return observations[-1], rewards[-1], done, info

    def _record(self, actions, observations, rewards, done) -> None:
        state_id, ir, programl, autophase, instcount, instruction_count = observations
        instruction_count_reward = float(rewards[0])

        self.step_buffer.append(
            (
                str(self.benchmark.uri),
                1 if done else 0,
                instruction_count_reward,
                state_id,
                " ".join(str(x) for x in actions),
            )
        )

        self.observations_buffer[state_id] = (
            instruction_count,
            zlib.compress(ir.encode("utf-8")),
            zlib.compress(pickle.dumps(programl)),
            " ".join(str(x) for x in autophase),
            " ".join(str(x) for x in instcount),
        )

        if (
            len(self.step_buffer) >= self.max_step_buffer_length
            or time() - self.last_commit >= self.commit_frequency
        ):
            self.flush()

    def close(self):
        self.flush()
        self.env.close()

    def fork(self):
        raise NotImplementedError
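
The "Decode:" comments in DB_CREATION_SCRIPT describe how each column is stored. The following is a minimal sketch, not part of this commit, of how logged rows could be read back and decoded; the helper name read_states and the example.db path are assumptions for illustration.

    import pickle
    import sqlite3
    import zlib

    import numpy as np


    def read_states(db_path="example.db", limit=5):
        # Join each logged state with its observation record and decode the
        # fields according to the schema comments in DB_CREATION_SCRIPT.
        with sqlite3.connect(db_path) as connection:
            rows = connection.execute(
                "SELECT S.benchmark_uri, S.actions, O.autophase, O.compressed_llvm_ir "
                "FROM States AS S "
                "JOIN Observations AS O ON S.state_id = O.state_id "
                "LIMIT ?",
                (limit,),
            )
            for benchmark_uri, actions, autophase, compressed_ir in rows:
                yield (
                    benchmark_uri,
                    [int(x) for x in actions.split()],  # action indices
                    np.array([int(x) for x in autophase.split()], dtype=np.int64),
                    zlib.decompress(compressed_ir).decode("utf-8"),  # LLVM-IR text
                )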
@@ -0,0 +1,38 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Unit tests for //compiler_gym/wrappers."""
import pytest

from compiler_gym.envs.llvm import LlvmEnv
from compiler_gym.wrappers import CompilerEnvWrapper, SynchronousSqliteLogger
from tests.test_main import main

pytest_plugins = ["tests.pytest_plugins.llvm"]


def test_SynchronousSqliteLogger_creates_file(env: LlvmEnv, tmp_path):
    db_path = tmp_path / "example.db"
    env.observation_space = "Autophase"
    env.reward_space = "IrInstructionCount"
    env = SynchronousSqliteLogger(env, db_path)
    env.reset()
    env.step(0)
    env.flush()
    assert db_path.is_file()


def test_SynchronousSqliteLogger_requires_llvm_env(tmp_path):
    with pytest.raises(TypeError, match="Requires LlvmEnv base environment"):
        SynchronousSqliteLogger(1, tmp_path / "example.db")


def test_SynchronousSqliteLogger_wrapped_env(env: LlvmEnv, tmp_path):
    env = CompilerEnvWrapper(env)
    env = SynchronousSqliteLogger(env, tmp_path / "example.db")
    env.reset()


if __name__ == "__main__":
    main()
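
To tie the wrapper and tests together, here is a minimal end-to-end usage sketch that mirrors the class docstring above. It assumes a working compiler_gym installation with the LLVM environments registered; the ten random actions and the example.db path are arbitrary choices for illustration.

    from pathlib import Path

    import gym

    import compiler_gym  # registers the llvm-* environments with gym
    from compiler_gym.wrappers import SynchronousSqliteLogger

    # Wrap the pre-configured Autophase/IrInstructionCount environment, as in
    # the class docstring, and buffer transitions to example.db.
    env = SynchronousSqliteLogger(
        env=gym.make("llvm-autophase-ic-v0"),
        db_path=Path("example.db"),
    )
    try:
        env.reset()
        for _ in range(10):  # a handful of arbitrary random actions
            _, _, done, _ = env.step(env.action_space.sample())
            if done:
                break
        env.flush()  # force buffered steps and observations to disk
    finally:
        env.close()  # close() also flushes before closing the underlying env

The commit_frequency_in_seconds and max_step_buffer_length constructor arguments control how often the buffers are written automatically; flush() or close() can be called at any time to force a write.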