Skip to content

Commit

Permalink
[wrappers] Add synchronous SQLite database wrapper.
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisCummins committed May 12, 2022
1 parent 7d83233 commit 483f640
Show file tree
Hide file tree
Showing 6 changed files with 298 additions and 0 deletions.
1 change: 1 addition & 0 deletions compiler_gym/wrappers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ py_library(
"core.py",
"datasets.py",
"llvm.py",
"sqlite_logger.py",
"time_limit.py",
"validation.py",
],
Expand Down
1 change: 1 addition & 0 deletions compiler_gym/wrappers/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ set(WRAPPERS_SRCS
)
if(COMPILER_GYM_ENABLE_LLVM_ENV)
list(APPEND WRAPPERS_SRCS "llvm.py")
list(APPEND WRAPPERS_SRCS "sqlite_logger.py")
endif()
if(COMPILER_GYM_ENABLE_MLIR_ENV)
list(APPEND WRAPPERS_SRCS "mlir.py")
Expand Down
2 changes: 2 additions & 0 deletions compiler_gym/wrappers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
IterateOverBenchmarks,
RandomOrderBenchmarks,
)
from compiler_gym.wrappers.sqlite_logger import SynchronousSqliteLogger

if config.enable_llvm_env:
from compiler_gym.wrappers.llvm import RuntimePointEstimateReward # noqa: F401
Expand All @@ -63,6 +64,7 @@
"ObservationWrapper",
"RandomOrderBenchmarks",
"RewardWrapper",
"SynchronousSqliteLogger",
"TimeLimit",
"ValidateBenchmarkAfterEveryStep",
]
Expand Down
245 changes: 245 additions & 0 deletions compiler_gym/wrappers/sqlite_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""This module implements a wrapper that logs state transitions to an sqlite
database.
"""
import logging
import pickle
import sqlite3
import zlib
from pathlib import Path
from time import time
from typing import Iterable, Optional, Union

import numpy as np

from compiler_gym.envs import LlvmEnv
from compiler_gym.spaces import Reward
from compiler_gym.util.gym_type_hints import ActionType
from compiler_gym.util.timer import Timer, humanize_duration
from compiler_gym.views import ObservationSpaceSpec
from compiler_gym.wrappers import CompilerEnvWrapper

DB_CREATION_SCRIPT = """
CREATE TABLE IF NOT EXISTS States (
benchmark_uri TEXT NOT NULL, -- The URI of the benchmark.
done INTEGER NOT NULL, -- 0 = False, 1 = True.
ir_instruction_count_oz_reward REAL NULLABLE,
state_id TEXT NOT NULL, -- 40-char sha1.
actions TEXT NOT NULL, -- Decode: [int(x) for x in field.split()]
PRIMARY KEY (benchmark_uri, actions),
FOREIGN KEY (state_id) REFERENCES Observations(state_id) ON UPDATE CASCADE
);
CREATE TABLE IF NOT EXISTS Observations (
state_id TEXT NOT NULL, -- 40-char sha1.
ir_instruction_count INTEGER NOT NULL,
compressed_llvm_ir BLOB NOT NULL, -- Decode: zlib.decompress(...)
pickled_compressed_programl BLOB NOT NULL, -- Decode: pickle.loads(zlib.decompress(...))
autophase TEXT NOT NULL, -- Decode: np.array([int(x) for x in field.split()], dtype=np.int64)
instcount TEXT NOT NULL, -- Decode: np.array([int(x) for x in field.split()], dtype=np.int64)
PRIMARY KEY (state_id)
);
"""


class SynchronousSqliteLogger(CompilerEnvWrapper):
"""A wrapper for an LLVM environment that logs all transitions to an sqlite
database.
Wrap an existing LLVM environment and then use it as per normal:
>>> env = SynchronousSqliteLogger(
... env=gym.make("llvm-autophase-ic-v0"),
... db_path="example.db",
... )
Connect to the database file you specified:
.. code-block::
$ sqlite3 example.db
There are two tables:
1. States: records every unique combination of benchmark + actions. For each
entry, records an identifying state ID, the episode reward, and whether
the episode is terminated:
.. code-block::
sqlite> .mode markdown
sqlite> .headers on
sqlite> select * from States limit 5;
| benchmark_uri | done | ir_instruction_count_oz_reward | state_id | actions |
|--------------------------|------|--------------------------------|------------------------------------------|----------------|
| generator://csmith-v0/99 | 0 | 0.0 | d625b874e58f6d357b816e21871297ac5c001cf0 | |
| generator://csmith-v0/99 | 0 | 0.0 | d625b874e58f6d357b816e21871297ac5c001cf0 | 31 |
| generator://csmith-v0/99 | 0 | 0.0 | 52f7142ef606d8b1dec2ff3371c7452c8d7b81ea | 31 116 |
| generator://csmith-v0/99 | 0 | 0.268005818128586 | d8c05bd41b7a6c6157b6a8f0f5093907c7cc7ecf | 31 116 103 |
| generator://csmith-v0/99 | 0 | 0.288621664047241 | c4d7ecd3807793a0d8bc281104c7f5a8aa4670f9 | 31 116 103 109 |
2. Observations: records pickled, compressed, and text observation values
for each unique state.
"""

def __init__(
self,
env: LlvmEnv,
db_path: Path,
commit_frequency_in_seconds: int = 300,
max_step_buffer_length: int = 5000,
):
super().__init__(env)
if not hasattr(env, "unwrapped"):
raise TypeError("Requires LlvmEnv base environment")
if not isinstance(self.unwrapped, LlvmEnv):
raise TypeError("Requires LlvmEnv base environment")
db_path.parent.mkdir(exist_ok=True, parents=True)
self.connection = sqlite3.connect(db_path)
self.cursor = self.connection.cursor()
self.commit_frequency = commit_frequency_in_seconds
self.max_step_buffer_length = max_step_buffer_length

self.cursor.executescript(DB_CREATION_SCRIPT)
self.connection.commit()
self.last_commit = time()

self.observations_buffer = {}
self.step_buffer = []

# House keeping notice: Keep these lists in sync with record().
self._observations = [
self.env.observation.spaces["IrSha1"],
self.env.observation.spaces["Ir"],
self.env.observation.spaces["Programl"],
self.env.observation.spaces["Autophase"],
self.env.observation.spaces["InstCount"],
self.env.observation.spaces["IrInstructionCount"],
]
self._rewards = [
self.env.reward.spaces["IrInstructionCountOz"],
self.env.reward.spaces["IrInstructionCount"],
]
self._reward_totals = np.zeros(len(self._rewards))

def flush(self) -> None:
"""Flush the buffered steps and observations to database."""
n_steps, n_observations = len(self.step_buffer), len(self.observations_buffer)

# Nothing to flush.
if not n_steps:
return

with Timer() as flush_time:
# House keeping notice: Keep these statements in sync with record().
self.cursor.executemany(
"INSERT OR IGNORE INTO States VALUES (?, ?, ?, ?, ?)",
self.step_buffer,
)
self.cursor.executemany(
"INSERT OR IGNORE INTO Observations VALUES (?, ?, ?, ?, ?, ?)",
((k, *v) for k, v in self.observations_buffer.items()),
)
self.step_buffer = []
self.observations_buffer = {}

self.connection.commit()

logging.info(
"Wrote %d state records and %d observations in %s. Last flush %s ago",
n_steps,
n_observations,
flush_time,
humanize_duration(time() - self.last_commit),
)
self.last_commit = time()

def reset(self, *args, **kwargs):
observation = self.env.reset(*args, **kwargs)
observations, rewards, done, info = self.env.multistep(
actions=[],
observation_spaces=self._observations,
reward_spaces=self._rewards,
)
assert not done, f"reset() failed! {info}"
self._reward_totals = np.array(rewards, dtype=np.float32)
rewards = self._reward_totals
self._record(
actions=self.actions,
observations=observations,
rewards=self._reward_totals,
done=False,
)
return observation

def step(
self,
action: ActionType,
observation_spaces: Optional[Iterable[Union[str, ObservationSpaceSpec]]] = None,
reward_spaces: Optional[Iterable[Union[str, Reward]]] = None,
observations: Optional[Iterable[Union[str, ObservationSpaceSpec]]] = None,
rewards: Optional[Iterable[Union[str, Reward]]] = None,
):
assert self.observation_space, "No observation space set"
assert self.reward_space, "No reward space set"
assert (
observation_spaces is None
), "SynchronousSqliteLogger does not support observation_spaces"
assert (
reward_spaces is None
), "SynchronousSqliteLogger does not support reward_spaces"
assert (
observations is None
), "SynchronousSqliteLogger does not support observations"
assert rewards is None, "SynchronousSqliteLogger does not support rewards"

observations, rewards, done, info = self.env.step(
action=action,
observation_spaces=self._observations + [self.observation_space_spec],
reward_spaces=self._rewards + [self.reward_space],
)
self._reward_totals += rewards[:-1]
self._record(
actions=self.actions,
observations=observations[:-1],
rewards=self._reward_totals,
done=done,
)
return observations[-1], rewards[-1], done, info

def _record(self, actions, observations, rewards, done) -> None:
state_id, ir, programl, autophase, instcount, instruction_count = observations
instruction_count_reward = float(rewards[0])

self.step_buffer.append(
(
str(self.benchmark.uri),
1 if done else 0,
instruction_count_reward,
state_id,
" ".join(str(x) for x in actions),
)
)

self.observations_buffer[state_id] = (
instruction_count,
zlib.compress(ir.encode("utf-8")),
zlib.compress(pickle.dumps(programl)),
" ".join(str(x) for x in autophase),
" ".join(str(x) for x in instcount),
)

if (
len(self.step_buffer) >= self.max_step_buffer_length
or time() - self.last_commit >= self.commit_frequency
):
self.flush()

def close(self):
self.flush()
self.env.close()

def fork(self):
raise NotImplementedError
11 changes: 11 additions & 0 deletions tests/wrappers/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,17 @@ py_test(
],
)

py_test(
name = "sqlite_logger_test",
timeout = "short",
srcs = ["sqlite_logger_test.py"],
deps = [
"//compiler_gym/wrappers",
"//tests:test_main",
"//tests/pytest_plugins:llvm",
],
)

py_test(
name = "time_limit_wrappers_test",
timeout = "short",
Expand Down
38 changes: 38 additions & 0 deletions tests/wrappers/sqlite_logger_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Unit tests for //compiler_gym/wrappers."""
import pytest

from compiler_gym.envs.llvm import LlvmEnv
from compiler_gym.wrappers import CompilerEnvWrapper, SynchronousSqliteLogger
from tests.test_main import main

pytest_plugins = ["tests.pytest_plugins.llvm"]


def test_SynchronousSqliteLogger_creates_file(env: LlvmEnv, tmp_path):
db_path = tmp_path / "example.db"
env.observation_space = "Autophase"
env.reward_space = "IrInstructionCount"
env = SynchronousSqliteLogger(env, db_path)
env.reset()
env.step(0)
env.flush()
assert db_path.is_file()


def test_SynchronousSqliteLogger_requires_llvm_env(tmp_path):
with pytest.raises(TypeError, match="Requires LlvmEnv base environment"):
SynchronousSqliteLogger(1, tmp_path / "example.db")


def test_SynchronousSqliteLogger_wrapped_env(env: LlvmEnv, tmp_path):
env = CompilerEnvWrapper(env)
env = SynchronousSqliteLogger(env, tmp_path / "example.db")
env.reset()


if __name__ == "__main__":
main()

0 comments on commit 483f640

Please sign in to comment.