Skip to content

Commit

Permalink
add cache to speed up manage_clvm/clvm_hex pre-commit check (Chia-Net…
Browse files Browse the repository at this point in the history
…work#14177)

* add cache to speed up manage_clvm/clvm_hex pre-commit check

* Update tools/manage_clvm.py

* add version to cache along with handling of incorrect versions

* force skipping the cache in ci
  • Loading branch information
altendky authored Jan 3, 2023
1 parent c826aec commit 026d467
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 20 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,6 @@ jobs:

- uses: chia-network/actions/activate-venv@main

- run: pre-commit run --all-files --verbose
- env:
CHIA_MANAGE_CLVM_CHECK_USE_CACHE: "false"
run: pre-commit run --all-files --verbose
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ build/
# https://pytest-monitor.readthedocs.io/en/latest/operating.html?highlight=.pymon#storage
.pymon

# cache for tooling such as tools/manage_clvm.py
.chia_cache/

# ===== =====
# DO NOT EDIT BELOW - GENERATED
Expand Down
156 changes: 137 additions & 19 deletions tools/manage_clvm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import dataclasses
import hashlib
import json
import os
import pathlib
import sys
Expand All @@ -13,6 +15,7 @@

here = pathlib.Path(__file__).parent.resolve()
root = here.parent
cache_path = root.joinpath(".chia_cache", "manage_clvm.json")

# This is a work-around for fixing imports so they get the appropriate top level
# packages instead of those of the same name in the same directory as this program.
Expand All @@ -34,6 +37,66 @@
top_levels = {"chia"}


class ManageClvmError(Exception):
pass


class CacheEntry(typing.TypedDict):
clvm: str
hex: str
hash: str


CacheEntries = typing.Dict[str, CacheEntry]
CacheVersion = typing.List[int]
current_cache_version: CacheVersion = [1]


class CacheVersionError(ManageClvmError):
pass


class NoCacheVersionError(CacheVersionError):
def __init__(self) -> None:
super().__init__("Cache must specify a version, none found")


class WrongCacheVersionError(CacheVersionError):
def __init__(self, found_version: object, expected_version: CacheVersion) -> None:
self.found_version = found_version
self.expected_version = expected_version
super().__init__(f"Cache has wrong version, expected {expected_version!r} got: {found_version!r}")


class Cache(typing.TypedDict):
entries: CacheEntries
version: CacheVersion


def create_empty_cache() -> Cache:
return {
"entries": {},
"version": current_cache_version,
}


def load_cache(file: typing.IO[str]) -> Cache:
loaded_cache = typing.cast(Cache, json.load(file))
try:
loaded_version = loaded_cache["version"]
except KeyError as e:
raise NoCacheVersionError() from e

if loaded_version != current_cache_version:
raise WrongCacheVersionError(found_version=loaded_version, expected_version=current_cache_version)

return loaded_cache


def dump_cache(cache: Cache, file: typing.IO[str]) -> None:
json.dump(cache, file, indent=4)


def generate_hash_bytes(hex_bytes: bytes) -> bytes:
cleaned_blob = bytes.fromhex(hex_bytes.decode("utf-8"))
serialize_program = SerializedProgram.from_bytes(cleaned_blob)
Expand Down Expand Up @@ -97,16 +160,57 @@ def find_stems(
return found_stems


def create_cache_entry(reference_paths: ClvmPaths, reference_bytes: ClvmBytes) -> CacheEntry:
source_bytes = reference_paths.clvm.read_bytes()

clvm_hasher = hashlib.sha256()
clvm_hasher.update(source_bytes)

hex_hasher = hashlib.sha256()
hex_hasher.update(reference_bytes.hex)

hash_hasher = hashlib.sha256()
hash_hasher.update(reference_bytes.hash)

return {
"clvm": clvm_hasher.hexdigest(),
"hex": hex_hasher.hexdigest(),
"hash": hash_hasher.hexdigest(),
}


@click.group()
def main() -> None:
pass


@main.command()
def check() -> int:
@click.option("--use-cache/--no-cache", default=True, show_default=True, envvar="USE_CACHE")
def check(use_cache: bool) -> int:
used_excludes = set()
overall_fail = False

cache: Cache
if not use_cache:
cache = create_empty_cache()
else:
try:
print(f"Attempting to load cache from: {cache_path}")
with cache_path.open(mode="r") as file:
cache = load_cache(file=file)
except FileNotFoundError:
print("Cache not found, starting fresh")
cache = create_empty_cache()
except NoCacheVersionError:
print("Ignoring cache due to lack of version")
cache = create_empty_cache()
except WrongCacheVersionError as e:
print(f"Ignoring cache due to incorrect version, expected {e.expected_version!r} got: {e.found_version!r}")
cache = create_empty_cache()

cache_entries = cache["entries"]
cache_modified = False

found_stems = find_stems(top_levels)
for name in ["hex", "hash"]:
found = found_stems[name]
Expand Down Expand Up @@ -134,27 +238,36 @@ def check() -> int:
file_fail = False
error = None

cache_key = str(stem_path)
try:
reference_paths = ClvmPaths.from_clvm(clvm=clvm_path)
reference_bytes = ClvmBytes.from_clvm_paths(paths=reference_paths)

with tempfile.TemporaryDirectory() as temporary_directory:
generated_paths = ClvmPaths.from_clvm(
clvm=pathlib.Path(temporary_directory).joinpath(f"generated{clvm_suffix}")
)

compile_clvm(
input_path=os.fspath(reference_paths.clvm),
output_path=os.fspath(generated_paths.hex),
search_paths=[os.fspath(reference_paths.clvm.parent)],
)

generated_bytes = ClvmBytes.from_hex_bytes(hex_bytes=generated_paths.hex.read_bytes())

if generated_bytes != reference_bytes:
file_fail = True
error = f" reference: {reference_bytes!r}\n"
error += f" generated: {generated_bytes!r}"
new_cache_entry = create_cache_entry(reference_paths=reference_paths, reference_bytes=reference_bytes)
existing_cache_entry = cache_entries.get(cache_key)
cache_hit = new_cache_entry == existing_cache_entry

if not cache_hit:
with tempfile.TemporaryDirectory() as temporary_directory:
generated_paths = ClvmPaths.from_clvm(
clvm=pathlib.Path(temporary_directory).joinpath(f"generated{clvm_suffix}")
)

compile_clvm(
input_path=os.fspath(reference_paths.clvm),
output_path=os.fspath(generated_paths.hex),
search_paths=[os.fspath(reference_paths.clvm.parent)],
)

generated_bytes = ClvmBytes.from_hex_bytes(hex_bytes=generated_paths.hex.read_bytes())

if generated_bytes != reference_bytes:
file_fail = True
error = f" reference: {reference_bytes!r}\n"
error += f" generated: {generated_bytes!r}"
else:
cache_modified = True
cache_entries[cache_key] = new_cache_entry
except Exception:
file_fail = True
error = traceback.format_exc()
Expand All @@ -178,6 +291,11 @@ def check() -> int:
for exclude in unused_excludes:
print(f" {exclude}")

if use_cache and cache_modified:
cache_path.parent.mkdir(parents=True, exist_ok=True)
with cache_path.open(mode="w") as file:
dump_cache(cache=cache, file=file)

return 1 if overall_fail else 0


Expand Down Expand Up @@ -229,4 +347,4 @@ def build() -> int:
return 1 if overall_fail else 0


sys.exit(main())
sys.exit(main(auto_envvar_prefix="CHIA_MANAGE_CLVM"))

0 comments on commit 026d467

Please sign in to comment.