From 9c811c3cdba3d9e256a001caf36903e9f463dc8f Mon Sep 17 00:00:00 2001 From: shkoo Date: Mon, 12 Sep 2022 21:24:07 -0700 Subject: [PATCH] Add initial support to profile guest runs (#274) * Add initial support to profile guest runs This is accessible with the --prof-out option of r0vm, or by using risc0_zkvm::host::profiler::Profiler. It produces output in a format suitable for use with google's pprof, but does not yet supply full stack traces. --- Cargo-guest.lock | 4 +- Cargo-host.lock | 258 +++++++++++++--- risc0/zkvm/r0vm/BUILD.bazel | 11 - risc0/zkvm/r0vm/Cargo.toml | 7 +- risc0/zkvm/r0vm/src/bin/r0vm.rs | 141 +++++++-- risc0/zkvm/sdk/cpp/guest/test/BUILD.bazel | 4 +- risc0/zkvm/sdk/cpp/guest/test/test.cpp | 4 +- risc0/zkvm/sdk/rust/BUILD.bazel | 1 + risc0/zkvm/sdk/rust/Cargo.toml | 8 + risc0/zkvm/sdk/rust/build.rs | 6 + risc0/zkvm/sdk/rust/methods/BUILD.bazel | 24 +- risc0/zkvm/sdk/rust/methods/Cargo.toml | 3 + risc0/zkvm/sdk/rust/methods/inner/Cargo.toml | 4 + .../src/bin/{sha_accel.rs => multi_test.rs} | 43 ++- .../sdk/rust/methods/inner/src/bin/verify.rs | 4 + risc0/zkvm/sdk/rust/methods/src/lib.rs | 3 +- risc0/zkvm/sdk/rust/methods/src/multi_test.rs | 26 ++ risc0/zkvm/sdk/rust/src/host/mod.rs | 230 +++++++++++++- risc0/zkvm/sdk/rust/src/host/profile.proto | 212 +++++++++++++ risc0/zkvm/sdk/rust/src/host/profiler.rs | 291 ++++++++++++++++++ risc0/zkvm/sdk/rust/src/prove/exec.rs | 38 ++- risc0/zkvm/sdk/rust/src/prove/mod.rs | 20 +- 22 files changed, 1216 insertions(+), 126 deletions(-) rename risc0/zkvm/sdk/rust/methods/inner/src/bin/{sha_accel.rs => multi_test.rs} (64%) create mode 100644 risc0/zkvm/sdk/rust/methods/src/multi_test.rs create mode 100644 risc0/zkvm/sdk/rust/src/host/profile.proto create mode 100644 risc0/zkvm/sdk/rust/src/host/profiler.rs diff --git a/Cargo-guest.lock b/Cargo-guest.lock index 957cc73c81..fe378fcd77 100644 --- a/Cargo-guest.lock +++ b/Cargo-guest.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "anyhow" -version = "1.0.63" +version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a26fa4d7e3f2eebadf743988fc8aec9fa9a9e82611acafd77c1462ed6262440a" +checksum = "b9a8f622bcf6ff3df478e9deba3e03e4e04b300f8e6a139e192c05fa3490afc7" [[package]] name = "array-init" diff --git a/Cargo-host.lock b/Cargo-host.lock index c3286990b3..2fbd447c45 100644 --- a/Cargo-host.lock +++ b/Cargo-host.lock @@ -2,6 +2,20 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "addr2line" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ca9b76e919fd83ccfb509f51b28c333c0e03f2221616e347a129215cec4e4a9" +dependencies = [ + "cpp_demangle", + "fallible-iterator", + "gimli", + "object", + "rustc-demangle", + "smallvec", +] + [[package]] name = "adler" version = "1.0.2" @@ -31,9 +45,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.63" +version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a26fa4d7e3f2eebadf743988fc8aec9fa9a9e82611acafd77c1462ed6262440a" +checksum = "b9a8f622bcf6ff3df478e9deba3e03e4e04b300f8e6a139e192c05fa3490afc7" [[package]] name = "array-init" @@ -86,6 +100,15 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "autotools" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8138adefca3e5d2e73bfba83bd6eeaf904b26a7ac1b4a19892cfe16cc7e1701" +dependencies = [ + "cc", +] + [[package]] name = "base64" version = "0.13.0" @@ -106,9 +129,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "block-buffer" -version = "0.10.2" +version = "0.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bf7fe51849ea569fd452f37822f606a5cabb684dc918707a0193fd4664ff324" +checksum = "69cce20737498f97b993470a6e536b8523f0af7892a4f928cceb1ac5e52ebe7e" dependencies = [ "generic-array", ] @@ -333,11 +356,20 @@ version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5827cebf4670468b8772dd191856768aedcb1b0278a04f989f7766351917b9dc" +[[package]] +name = "cpp_demangle" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eeaa953eaad386a53111e47172c2fedba671e5684c8dd601a5f474f4f118710f" +dependencies = [ + "cfg-if", +] + [[package]] name = "cpufeatures" -version = "0.2.4" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc948ebb96241bb40ab73effeb80d9f93afaad49359d159a5e61be51619fe813" +checksum = "28d997bd5e24a5928dd43e46dc529867e207907fe0b239c3477d924f7f2ca320" dependencies = [ "libc", ] @@ -573,6 +605,12 @@ dependencies = [ "termcolor", ] +[[package]] +name = "fallible-iterator" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" + [[package]] name = "fastrand" version = "1.8.0" @@ -582,6 +620,12 @@ dependencies = [ "instant", ] +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + [[package]] name = "flate2" version = "1.0.24" @@ -615,11 +659,10 @@ checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" [[package]] name = "form_urlencoded" -version = "1.0.1" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5fc25a87fa4fd2094bffb06925852034d90a17f0d1e05197d4956d3555752191" +checksum = "a9c384f161156f5260c24a097c56119f9be8c798586aecc13afbcbe7b7e26bf8" dependencies = [ - "matches", "percent-encoding", ] @@ -692,6 +735,17 @@ dependencies = [ "wasi", ] +[[package]] +name = "gimli" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22030e2c5a68ec659fde1e949a745124b48e6fa8b045b7ed5bd1fe4ccc5c4e5d" +dependencies = [ + "fallible-iterator", + "indexmap", + "stable_deref_trait", +] + [[package]] name = "glob" version = "0.3.0" @@ -869,11 +923,10 @@ dependencies = [ [[package]] name = "idna" -version = "0.2.3" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "418a0a6fab821475f634efe3ccc45c013f742efe03d853e8d3355d5cb850ecf8" +checksum = "e14ddfc70884202db2244c223200c204c2bda1bc6e0998d11b5e024d657209e6" dependencies = [ - "matches", "unicode-bidi", "unicode-normalization", ] @@ -953,9 +1006,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.59" +version = "0.3.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "258451ab10b34f8af53416d1fdab72c22e805f0c92a1136d59470ec0b11138b2" +checksum = "49409df3e3bf0856b916e2ceaca09ee28e6871cf7d9ce97a692cacfdb2a25a47" dependencies = [ "wasm-bindgen", ] @@ -1016,12 +1069,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "matches" -version = "0.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3e378b66a060d48947b590737b30a1be76706c8dd7b8ba0f2fe3989c68a853f" - [[package]] name = "matrixmultiply" version = "0.3.2" @@ -1054,9 +1101,9 @@ checksum = "2a60c7ce501c71e03a9c9c0d35b861413ae925bd979cc7a4e30d060069aaac8d" [[package]] name = "miniz_oxide" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f5c75688da582b8ffc1f1799e9db273f32133c49e048f614d22ec3256773ccc" +checksum = "96590ba8f175222643a85693f33d26e9c8a015f599c216509b1a6894af675d34" dependencies = [ "adler", ] @@ -1073,6 +1120,12 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "multimap" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" + [[package]] name = "native-tls" version = "0.2.10" @@ -1152,6 +1205,16 @@ dependencies = [ "libc", ] +[[package]] +name = "object" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21158b2c33aa6d4561f1c0a6ea283ca92bc54802a93b263e910746d679a7eb53" +dependencies = [ + "flate2", + "memchr", +] + [[package]] name = "once_cell" version = "1.14.0" @@ -1252,9 +1315,19 @@ dependencies = [ [[package]] name = "percent-encoding" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4fd5641d01c8f18a23da7b6fe29298ff4b55afcccdf78973b24cf3175fee32e" +checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" + +[[package]] +name = "petgraph" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6d5014253a1331579ce62aa67443b4a658c5e7dd03d4bc6d302b94474888143" +dependencies = [ + "fixedbitset", + "indexmap", +] [[package]] name = "pin-project-lite" @@ -1276,9 +1349,9 @@ checksum = "1df8c4ec4b0627e53bdf214615ad287367e482558cf84b109250b37464dc03ae" [[package]] name = "plotters" -version = "0.3.3" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "716b4eeb6c4a1d3ecc956f75b43ec2e8e8ba80026413e70a3f41fd3313d3492b" +checksum = "2538b639e642295546c50fcd545198c9d64ee2a38620a628724a3b266d5fbf97" dependencies = [ "num-traits", "plotters-backend", @@ -1368,6 +1441,68 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "prost" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "399c3c31cdec40583bb68f0b18403400d01ec4289c383aa047560439952c4dd7" +dependencies = [ + "bytes", + "prost-derive", +] + +[[package]] +name = "prost-build" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f835c582e6bd972ba8347313300219fed5bfa52caf175298d860b61ff6069bb" +dependencies = [ + "bytes", + "heck", + "itertools", + "lazy_static", + "log", + "multimap", + "petgraph", + "prost", + "prost-types", + "regex", + "tempfile", + "which", +] + +[[package]] +name = "prost-derive" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7345d5f0e08c0536d7ac7229952590239e77abf0a0100a1b1d890add6ea96364" +dependencies = [ + "anyhow", + "itertools", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "prost-types" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4dfaa718ad76a44b3415e6c4d53b17c8f99160dcb3a99b10470fce8ad43f6e3e" +dependencies = [ + "bytes", + "prost", +] + +[[package]] +name = "protobuf-src" +version = "1.1.0+21.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7ac8852baeb3cc6fb83b93646fb93c0ffe5d14bf138c945ceb4b9948ee0e3c1" +dependencies = [ + "autotools", +] + [[package]] name = "quote" version = "1.0.21" @@ -1612,15 +1747,20 @@ dependencies = [ name = "risc0-zkvm" version = "0.11.1" dependencies = [ + "addr2line", "anyhow", "bytemuck", "criterion", "ctor", "cxx", "env_logger", + "gimli", "lazy-regex", "lazy_static", "log", + "prost", + "prost-build", + "protobuf-src", "rand", "risc0-zkp", "risc0-zkvm-circuit", @@ -1701,6 +1841,12 @@ dependencies = [ "tbb-sys", ] +[[package]] +name = "rustc-demangle" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ef03e0a2b150c7a90d01faf6254c9c48a41e95fb2a8c2ac1c6f0d2b9aefc342" + [[package]] name = "rustc-std-workspace-core" version = "1.0.0" @@ -1891,6 +2037,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "smallvec" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2fd0db749597d91ff862fd1d55ea87f7855a744a8425a64695b6fca237d1dad1" + [[package]] name = "socket2" version = "0.4.7" @@ -1907,6 +2059,12 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d" +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + [[package]] name = "strsim" version = "0.10.0" @@ -2088,9 +2246,9 @@ dependencies = [ [[package]] name = "tokio-util" -version = "0.7.3" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc463cd8deddc3770d20f9852143d50bf6094e640b485cb2e189a2099085ff45" +checksum = "0bb2e075f03b3d66d8d8785356224ba688d2906a371015e225beeb65ca92c740" dependencies = [ "bytes", "futures-core", @@ -2173,13 +2331,12 @@ checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" [[package]] name = "url" -version = "2.2.2" +version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a507c383b2d33b5fc35d1861e77e6b383d158b2da5e14fe51b83dfedf6fd578c" +checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643" dependencies = [ "form_urlencoded", "idna", - "matches", "percent-encoding", ] @@ -2233,9 +2390,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.82" +version = "0.2.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc7652e3f6c4706c8d9cd54832c4a4ccb9b5336e2c3bd154d5cccfbf1c1f5f7d" +checksum = "eaf9f5aceeec8be17c128b2e93e031fb8a4d469bb9c4ae2d7dc1888b26887268" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -2243,9 +2400,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.82" +version = "0.2.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "662cd44805586bd52971b9586b1df85cdbbd9112e4ef4d8f41559c334dc6ac3f" +checksum = "4c8ffb332579b0557b52d268b91feab8df3615f265d5270fec2a8c95b17c1142" dependencies = [ "bumpalo", "log", @@ -2258,9 +2415,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.32" +version = "0.4.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa76fb221a1f8acddf5b54ace85912606980ad661ac7a503b4570ffd3a624dad" +checksum = "23639446165ca5a5de86ae1d8896b737ae80319560fbaa4c2887b7da6e7ebd7d" dependencies = [ "cfg-if", "js-sys", @@ -2270,9 +2427,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.82" +version = "0.2.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b260f13d3012071dfb1512849c033b1925038373aea48ced3012c09df952c602" +checksum = "052be0f94026e6cbc75cdefc9bae13fd6052cdcaf532fa6c45e7ae33a1e6c810" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2280,9 +2437,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.82" +version = "0.2.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5be8e654bdd9b79216c2929ab90721aa82faf65c48cdf08bdc4e7f51357b80da" +checksum = "07bc0c051dc5f23e307b13285f9d75df86bfdf816c5721e573dec1f9b8aa193c" dependencies = [ "proc-macro2", "quote", @@ -2293,15 +2450,15 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.82" +version = "0.2.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6598dd0bd3c7d51095ff6531a5b23e02acdc81804e30d8f07afb77b7215a140a" +checksum = "1c38c045535d93ec4f0b4defec448e4291638ee608530863b1e2ba115d4fff7f" [[package]] name = "web-sys" -version = "0.3.59" +version = "0.3.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed055ab27f941423197eb86b2035720b1a3ce40504df082cac2ecc6ed73335a1" +checksum = "bcda906d8be16e728fd5adc5b729afad4e444e106ab28cd1c7256e54fa61510f" dependencies = [ "js-sys", "wasm-bindgen", @@ -2326,6 +2483,17 @@ dependencies = [ "webpki", ] +[[package]] +name = "which" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c831fbbee9e129a8cf93e7747a82da9d95ba8e16621cae60ec2cdc849bacb7b" +dependencies = [ + "either", + "libc", + "once_cell", +] + [[package]] name = "winapi" version = "0.3.9" diff --git a/risc0/zkvm/r0vm/BUILD.bazel b/risc0/zkvm/r0vm/BUILD.bazel index 96dcbfc5ac..6c8e2071e2 100644 --- a/risc0/zkvm/r0vm/BUILD.bazel +++ b/risc0/zkvm/r0vm/BUILD.bazel @@ -1,14 +1,3 @@ load("@rules_rust//rust:defs.bzl", "rust_binary") exports_files(["Cargo.toml"]) - -rust_binary( - name = "r0vm", - srcs = ["src/bin/r0vm.rs"], - deps = [ - "//risc0/zkvm/sdk/rust:zkvm_host", - "@crates_host//:bytemuck", - "@crates_host//:clap", - "@crates_host//:env_logger", - ], -) diff --git a/risc0/zkvm/r0vm/Cargo.toml b/risc0/zkvm/r0vm/Cargo.toml index d77ed1ecfd..753d977fe3 100644 --- a/risc0/zkvm/r0vm/Cargo.toml +++ b/risc0/zkvm/r0vm/Cargo.toml @@ -8,13 +8,18 @@ homepage = "https://risczero.com/" repository = "https://github.com/risc0/risc0/" [dependencies] +anyhow = "1.0" bytemuck = "1.12" clap = { version = "3.2", features = ["derive"] } env_logger = "0.9.0" -risc0-zkvm = { version = "0.11", path = "../sdk/rust" } +risc0-zkvm = { version = "0.11", path = "../sdk/rust", features = ["profiler"] } [dev-dependencies] anyhow = "1.0" assert_cmd = "2.0" assert_fs = "1.0" risc0-zkvm-methods = { path = "../sdk/rust/methods" } + +[features] +pure-prove = ["risc0-zkvm/pure-prove"] + diff --git a/risc0/zkvm/r0vm/src/bin/r0vm.rs b/risc0/zkvm/r0vm/src/bin/r0vm.rs index a92ecd343a..26e588c35a 100644 --- a/risc0/zkvm/r0vm/src/bin/r0vm.rs +++ b/risc0/zkvm/r0vm/src/bin/r0vm.rs @@ -13,26 +13,39 @@ // limitations under the License. use std::default::Default; -use std::{fs, io::Write}; +use std::{ + fs, + io::Write, + path::{Path, PathBuf}, +}; +use anyhow::Result; use clap::Parser; -use risc0_zkvm::host::{MethodId, Prover, ProverOpts, Receipt, DEFAULT_METHOD_ID_LIMIT}; +use risc0_zkvm::host::{ + profiler::Profiler, MethodId, Prover, ProverOpts, Receipt, DEFAULT_METHOD_ID_LIMIT, +}; /// Generates a MethodID for a given RISC-V ELF binary. #[derive(Parser)] #[clap(about, version, author)] struct Args { /// The ELF file to run - #[clap(long)] - elf: String, + #[clap(long, parse(from_os_str))] + elf: PathBuf, /// MethodID file; created if needed and it doesn't exist. - #[clap(long)] - method_id: Option, + #[clap(long, parse(from_os_str))] + method_id: Option, /// Receipt output file. + #[clap(long, parse(from_os_str))] + receipt: Option, + + /// EXPERIMENTAL: When enabled, writes the receipt in a format usable by the + /// "verify" guest method. + #[cfg(feature = "pure-prove")] #[clap(long)] - receipt: Option, + input_for_verify: bool, /// Skip generating the seal in receipt. This should only be used /// for testing. In this case, performace will be much better but @@ -41,8 +54,8 @@ struct Args { skip_seal: bool, /// File to read initial input from. - #[clap(long)] - initial_input: Option, + #[clap(long, parse(from_os_str))] + initial_input: Option, /// Display verbose output. #[clap(short, long, action = clap::ArgAction::Count)] @@ -51,13 +64,15 @@ struct Args { /// Limit the number of hash table entries to compute. #[clap(short, long, default_value_t = DEFAULT_METHOD_ID_LIMIT)] limit: u32, + + /// Write "pprof" protobuf output of the guest's run to this file. + /// You can use google's pprof (https://github.com/google/pprof) + /// to read it. + #[clap(long, parse(from_os_str))] + pprof_out: Option, } -fn read_method_id( - verbose: u8, - elf_file: &str, - method_id_file: &Option, -) -> Option { +fn read_method_id(verbose: u8, elf_file: &Path, method_id_file: Option<&Path>) -> Option { let elf_mtime = fs::metadata(elf_file).ok()?.modified().ok()?; let id_mtime = fs::metadata(method_id_file.as_ref()?) .ok()? @@ -76,13 +91,44 @@ fn read_method_id( if verbose > 0 { println!( "Successfully read method id from {}", - method_id_file.as_ref().unwrap() + method_id_file.unwrap().display() ); } Some(id) } +fn run_prover( + elf_contents: &[u8], + method_id: &[u8], + opts: ProverOpts, + initial_input: Option>, +) -> Result<(Receipt, Vec)> { + let mut prover = Prover::new_with_opts(&elf_contents, method_id, opts).unwrap(); + if let Some(bytes) = initial_input { + prover.add_input_u8_slice(bytes.as_slice()); + } + let receipt = prover.run()?; + let output = prover.get_output()?; + Ok((receipt, output.to_vec())) +} + +fn encode_receipt(receipt: &Receipt, _method_id: &[u8], _args: &Args) -> Vec { + #[cfg(feature = "pure-prove")] + if _args.input_for_verify { + let mut encoded: Vec = Vec::new(); + let mut add_input_u32_slice = + |slice: &[u32]| encoded.write_all(bytemuck::cast_slice(slice)).unwrap(); + add_input_u32_slice(&[receipt.seal.len() as u32]); + add_input_u32_slice(&receipt.seal); + add_input_u32_slice(&[(_method_id.len() / 4) as u32]); + encoded.write_all(_method_id).unwrap(); + return encoded; + } + + bytemuck::cast_slice(risc0_zkvm::serde::to_vec(&receipt).unwrap().as_slice()).into() +} + fn main() { env_logger::init(); @@ -93,7 +139,7 @@ fn main() { eprintln!( "Read {} bytes of ELF from {}", elf_contents.len(), - &args.elf + args.elf.display() ); } @@ -102,15 +148,20 @@ fn main() { // generate an actual proof. MethodId::from_slice(&[]).unwrap() } else { - read_method_id(args.verbose, &args.elf, &args.method_id).unwrap_or_else(|| { + read_method_id( + args.verbose, + &args.elf, + args.method_id.as_ref().map(|p| p.as_path()), + ) + .unwrap_or_else(|| { if args.verbose > 0 { eprintln!("Computing method id"); } let computed = MethodId::compute_with_limit(&elf_contents, args.limit).unwrap(); - if let Some(method_id_file) = args.method_id { + if let Some(method_id_file) = args.method_id.as_ref() { std::fs::write(&method_id_file, computed.as_slice().unwrap()).unwrap(); if args.verbose > 0 { - eprintln!("Saved method id to {}", method_id_file); + eprintln!("Saved method id to {}", method_id_file.display()); } } computed @@ -120,18 +171,39 @@ fn main() { let opts: ProverOpts = ProverOpts::default().with_skip_seal(args.skip_seal || args.receipt.is_none()); - let mut prover = - Prover::new_with_opts(&elf_contents, method_id.as_slice().unwrap(), opts).unwrap(); - if let Some(input) = args.initial_input { - let input_bytes = fs::read(input).unwrap(); - if args.verbose > 0 { - eprintln!("Supplying {} bytes of initial input", input_bytes.len()); - } - prover.add_input_u8_slice(&input_bytes); + let mut guest_prof: Option = None; + + if args.pprof_out.is_some() { + guest_prof = Some(Profiler::new(args.elf.to_str().unwrap(), &elf_contents).unwrap()); } - let receipt: Receipt = prover.run().unwrap(); - let receipt_data = risc0_zkvm::serde::to_vec(&receipt).unwrap(); + let proof = run_prover( + &elf_contents, + method_id.as_slice().unwrap(), + if let Some(ref mut profiler) = guest_prof { + opts.with_trace_callback(profiler.make_trace_callback()) + } else { + opts + }, + args.initial_input.as_ref().map(|input| { + let input_bytes = fs::read(input).unwrap(); + if args.verbose > 0 { + eprintln!("Supplying {} bytes of initial input", input_bytes.len()); + } + input_bytes + }), + ); + + // Now that we're done with the prover, we can collect the guest profiling data. + if let Some(ref mut profiler) = guest_prof.as_mut() { + profiler.finalize(); + let report = profiler.encode_to_vec(); + fs::write(args.pprof_out.as_ref().unwrap(), &report) + .expect("Unable to write profiling output"); + } + let (receipt, output) = proof.expect("Run failed"); + + let receipt_data = encode_receipt(&receipt, method_id.as_slice().unwrap(), &args); if args.skip_seal || args.receipt.is_none() { if args.verbose > 0 { @@ -143,19 +215,20 @@ fn main() { receipt.verify(method_id.as_slice().unwrap()).unwrap(); } } - if let Some(receipt_file) = args.receipt { - fs::write(&receipt_file, bytemuck::cast_slice(&receipt_data)).unwrap(); + + if let Some(receipt_file) = args.receipt.as_ref() { + fs::write(receipt_file, receipt_data.as_slice()).expect("Unable to write receipt file"); if args.verbose > 0 { eprintln!( "Wrote {} bytes of receipt to {}", receipt_data.len(), - receipt_file + receipt_file.display() ); } } - let output = prover.get_output().unwrap(); + if args.verbose > 0 { eprintln!("Writing {} bytes of output to stdout", output.len()); } - std::io::stdout().write_all(output).unwrap(); + std::io::stdout().write_all(output.as_slice()).unwrap(); } diff --git a/risc0/zkvm/sdk/cpp/guest/test/BUILD.bazel b/risc0/zkvm/sdk/cpp/guest/test/BUILD.bazel index 485eec1f1f..077e02142a 100644 --- a/risc0/zkvm/sdk/cpp/guest/test/BUILD.bazel +++ b/risc0/zkvm/sdk/cpp/guest/test/BUILD.bazel @@ -45,10 +45,10 @@ cc_gtest( "//risc0/zkvm/sdk/rust/methods:test_fail.id", "//risc0/zkvm/sdk/rust/methods:test_mem", "//risc0/zkvm/sdk/rust/methods:test_mem.id", + "//risc0/zkvm/sdk/rust/methods:test_multi_test", + "//risc0/zkvm/sdk/rust/methods:test_multi_test.id", "//risc0/zkvm/sdk/rust/methods:test_sha", "//risc0/zkvm/sdk/rust/methods:test_sha.id", - "//risc0/zkvm/sdk/rust/methods:test_sha_accel", - "//risc0/zkvm/sdk/rust/methods:test_sha_accel.id", ], tags = ["exclusive"], deps = ["//risc0/zkvm/sdk/cpp/host"], diff --git a/risc0/zkvm/sdk/cpp/guest/test/test.cpp b/risc0/zkvm/sdk/cpp/guest/test/test.cpp index b753fcc819..d817f1b696 100644 --- a/risc0/zkvm/sdk/cpp/guest/test/test.cpp +++ b/risc0/zkvm/sdk/cpp/guest/test/test.cpp @@ -204,8 +204,8 @@ TEST(CoreTests, Memset) { } TEST(CoreTests, SHAAccel) { - MethodId methodId = loadMethodId("risc0/zkvm/sdk/rust/methods/test_sha_accel.id"); - Prover prover("risc0/zkvm/sdk/rust/methods/test_sha_accel", methodId); + MethodId methodId = loadMethodId("risc0/zkvm/sdk/rust/methods/test_multi_test.id"); + Prover prover("risc0/zkvm/sdk/rust/methods/test_multi_test", methodId); prover.writeInput(0); // Test risc0_zkvm_guest::sha::Impl prover.writeInput(0); // Compute an empty digest Receipt receipt = prover.run(); diff --git a/risc0/zkvm/sdk/rust/BUILD.bazel b/risc0/zkvm/sdk/rust/BUILD.bazel index 36bc902ca1..6583dec683 100644 --- a/risc0/zkvm/sdk/rust/BUILD.bazel +++ b/risc0/zkvm/sdk/rust/BUILD.bazel @@ -25,6 +25,7 @@ risc0_rust_library_pair( "//risc0/zkp/rust:zkp_host", "//risc0/zkvm/sdk/cpp/host", "//risc0/zkvm/sdk/rust/platform:platform_host", + "//risc0/zkvm/sdk/rust/methods:methods_host", "@crates_host//:anyhow", "@crates_host//:bytemuck", "@crates_host//:cxx", diff --git a/risc0/zkvm/sdk/rust/Cargo.toml b/risc0/zkvm/sdk/rust/Cargo.toml index 66917deedd..e6bba024f5 100644 --- a/risc0/zkvm/sdk/rust/Cargo.toml +++ b/risc0/zkvm/sdk/rust/Cargo.toml @@ -26,13 +26,20 @@ tempfile = "3.3" # # Host dependencies [target.'cfg(not(target_arch = "riscv32"))'.dependencies] +addr2line = { version = "0.18", optional = true } ctor = "0.1" cxx = "1.0" +gimli = { version = "0.26", optional = true } log = "0.4" rand = "0.8" risc0-zkvm-sys = { version = "0.11", path = "../.." } sha2 = "0.10" xmas-elf = "0.8" +prost = { version = "0.11", optional = true } + +[build-dependencies] +prost-build = { version = "0.11", optional = true } +protobuf-src = { version = "1.1", optional = true } [target.'cfg(not(target_arch = "riscv32"))'.dev-dependencies] env_logger = "0.9" @@ -47,6 +54,7 @@ host = ["risc0-zkp/host"] prove = ["circuit", "dep:lazy-regex", "risc0-zkp/prove", "risc0-zkvm-circuit/cpp"] std = ["risc0-zkp/std", "serde/std"] verify = ["circuit", "risc0-zkp/verify"] +profiler = ["dep:addr2line", "dep:gimli", "dep:prost", "dep:prost-build", "dep:protobuf-src"] # Run rust-based prover instead of FFI-based prover. pure-prove = [] diff --git a/risc0/zkvm/sdk/rust/build.rs b/risc0/zkvm/sdk/rust/build.rs index 6f8642843c..d3b15aaaf7 100644 --- a/risc0/zkvm/sdk/rust/build.rs +++ b/risc0/zkvm/sdk/rust/build.rs @@ -5,4 +5,10 @@ fn main() { println!("cargo:rustc-link-lib=static=risc0-zkp-sys"); println!("cargo:rustc-link-lib=static=risc0-zkvm-sys"); } + + #[cfg(feature = "profiler")] + { + std::env::set_var("PROTOC", protobuf_src::protoc()); + prost_build::compile_protos(&["src/host/profile.proto"], &["src/host/"]).unwrap(); + } } diff --git a/risc0/zkvm/sdk/rust/methods/BUILD.bazel b/risc0/zkvm/sdk/rust/methods/BUILD.bazel index 8ce3b883a8..47bbd4f7f7 100644 --- a/risc0/zkvm/sdk/rust/methods/BUILD.bazel +++ b/risc0/zkvm/sdk/rust/methods/BUILD.bazel @@ -1,12 +1,27 @@ -load("//bazel/rules/risc0:defs.bzl", "risc0_rust_method") +load("//bazel/rules/risc0:defs.bzl", "risc0_rust_library_pair", "risc0_rust_method") + +risc0_rust_library_pair( + name = "methods", + srcs = glob(["src/**/*.rs"]), + crate_name = "risc0_zkvm_methods", + guest_deps = [ + "@crates_guest//:serde", + ], + guest_features = ["bazel"], + host_deps = [ + "@crates_host//:serde", + ], + host_features = ["bazel"], + visibility = ["//visibility:public"], +) filegroup( name = "methods", srcs = [ "test_fail", "test_mem", + "test_multi_test", "test_sha", - "test_sha_accel", ], visibility = ["//visibility:public"], ) @@ -22,11 +37,12 @@ risc0_rust_method( ) risc0_rust_method( - name = "test_sha_accel", - srcs = ["inner/src/bin/sha_accel.rs"], + name = "test_multi_test", + srcs = ["inner/src/bin/multi_test.rs"], limit = 10, visibility = ["//visibility:public"], deps = [ + ":methods_guest", "//risc0/zkp/rust:zkp_guest", "//risc0/zkvm/sdk/rust/guest", ], diff --git a/risc0/zkvm/sdk/rust/methods/Cargo.toml b/risc0/zkvm/sdk/rust/methods/Cargo.toml index 2e3fa45dd9..01d1063f32 100644 --- a/risc0/zkvm/sdk/rust/methods/Cargo.toml +++ b/risc0/zkvm/sdk/rust/methods/Cargo.toml @@ -14,3 +14,6 @@ methods = ["inner"] [dependencies] serde = { version = "1.0", default-features = false, features = ["derive"] } + +[features] +pure-prove = ["risc0-build/pure-prove"] diff --git a/risc0/zkvm/sdk/rust/methods/inner/Cargo.toml b/risc0/zkvm/sdk/rust/methods/inner/Cargo.toml index e695e9ab4a..fd1e244e6f 100644 --- a/risc0/zkvm/sdk/rust/methods/inner/Cargo.toml +++ b/risc0/zkvm/sdk/rust/methods/inner/Cargo.toml @@ -19,6 +19,10 @@ serde = { version = "1.0", default-features = false, features = ["derive"] } lto = true opt-level = 3 +[profile.release.package.risc0-zkvm-methods-inner] +# Include debug symbols so we can test the profiler. +debug = 1 + [build-dependencies] risc0-build = { version = "0.11", path = "../../build" } diff --git a/risc0/zkvm/sdk/rust/methods/inner/src/bin/sha_accel.rs b/risc0/zkvm/sdk/rust/methods/inner/src/bin/multi_test.rs similarity index 64% rename from risc0/zkvm/sdk/rust/methods/inner/src/bin/sha_accel.rs rename to risc0/zkvm/sdk/rust/methods/inner/src/bin/multi_test.rs index 5d11b453f9..b65d563461 100644 --- a/risc0/zkvm/sdk/rust/methods/inner/src/bin/sha_accel.rs +++ b/risc0/zkvm/sdk/rust/methods/inner/src/bin/multi_test.rs @@ -12,31 +12,50 @@ // See the License for the specific language governing permissions and // limitations under the License. +//! Runs different tests based on the first u32 provided. + #![no_main] #![no_std] #![feature(alloc_error_handler)] +use core::arch::asm; + #[cfg(feature = "pure-prove")] use risc0_zkp::core::sha::{Digest, Sha}; use risc0_zkvm_guest::{env, sha}; +use risc0_zkvm_methods::multi_test::MultiTestSpec; risc0_zkvm_guest::entry!(main); risc0_zkvm_guest::standalone_handlers!(); +#[inline(never)] +#[no_mangle] +fn profile_test_func1() { + profile_test_func2() +} + +#[inline(always)] +#[no_mangle] +fn profile_test_func2() { + unsafe { asm!("nop") } +} + pub fn main() { - let impl_select: u32 = env::read(); + let impl_select: MultiTestSpec = env::read(); let data: &[u8] = env::read(); let digest = sha::digest_u8_slice(data); env::commit(&digest); match impl_select { - 0 => risc0_zkp::core::sha::testutil::test_sha_impl(&risc0_zkvm_guest::sha::Impl {}), + MultiTestSpec::ShaConforms => { + risc0_zkp::core::sha::testutil::test_sha_impl(&risc0_zkvm_guest::sha::Impl {}) + } #[cfg(feature = "pure-prove")] - 1 => { + MultiTestSpec::ShaInsecureConforms => { risc0_zkp::core::sha::testutil::test_sha_impl(&risc0_zkvm_guest::sha_insecure::Impl {}) } #[cfg(feature = "pure-prove")] - 2 => { + MultiTestSpec::ShaCycleCount => { // Time the simulated sha so that it estimates what we'd // see when it's a custom circuit. let a: &Digest = &Digest::new([1, 2, 3, 4, 5, 6, 7, 8]); @@ -57,6 +76,22 @@ pub fn main() { // our simulation doesn't run faster. assert!(total >= 72, "total: {total}"); } + #[cfg(feature = "pure-prove")] + MultiTestSpec::EventTrace => unsafe { + let mut _x: u32; + // Exeute some instructions with distinctive arguments + // that are easy to find in the event trace. + asm!(r" + li x5, 1337 + sw x5, 548(zero) +", out("x5") _x,); + }, + #[cfg(feature = "pure-prove")] + MultiTestSpec::Profiler => { + // Call an external function to make sure it's detected during profiling. + profile_test_func1() + } + #[cfg(not(feature = "pure-prove"))] _ => unimplemented!(), } } diff --git a/risc0/zkvm/sdk/rust/methods/inner/src/bin/verify.rs b/risc0/zkvm/sdk/rust/methods/inner/src/bin/verify.rs index 93b85fb933..a124524722 100644 --- a/risc0/zkvm/sdk/rust/methods/inner/src/bin/verify.rs +++ b/risc0/zkvm/sdk/rust/methods/inner/src/bin/verify.rs @@ -144,4 +144,8 @@ pub fn main() { verify_with_hal(&hal, method_id, seal).unwrap(); env::log("done"); + + // Avoid accidental cycle count regressions. + let cycles = env::get_cycle_count(); + assert!(cycles < 12_000_000, "Ran in {cycles} cycles; expecting under 12 million."); } diff --git a/risc0/zkvm/sdk/rust/methods/src/lib.rs b/risc0/zkvm/sdk/rust/methods/src/lib.rs index 7ddb0a81ad..6e25f4ae89 100644 --- a/risc0/zkvm/sdk/rust/methods/src/lib.rs +++ b/risc0/zkvm/sdk/rust/methods/src/lib.rs @@ -15,6 +15,7 @@ #![cfg_attr(not(feature = "std"), no_std)] pub mod bench; +pub mod multi_test; -#[cfg(not(target_os = "zkvm"))] +#[cfg(not(any(target_os = "zkvm", feature = "bazel")))] include!(concat!(env!("OUT_DIR"), "/methods.rs")); diff --git a/risc0/zkvm/sdk/rust/methods/src/multi_test.rs b/risc0/zkvm/sdk/rust/methods/src/multi_test.rs new file mode 100644 index 0000000000..8a13f6c568 --- /dev/null +++ b/risc0/zkvm/sdk/rust/methods/src/multi_test.rs @@ -0,0 +1,26 @@ +// Copyright 2022 Risc0, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Definitions for test selection codes used by the "multi_test" test. + +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize)] +pub enum MultiTestSpec { + ShaConforms, + ShaInsecureConforms, + ShaCycleCount, + EventTrace, + Profiler, +} diff --git a/risc0/zkvm/sdk/rust/src/host/mod.rs b/risc0/zkvm/sdk/rust/src/host/mod.rs index 541e0f6b1a..41087652d8 100644 --- a/risc0/zkvm/sdk/rust/src/host/mod.rs +++ b/risc0/zkvm/sdk/rust/src/host/mod.rs @@ -30,17 +30,51 @@ pub use exception::Exception; use ffi as prove; pub use prove::{MethodId, Prover, Receipt}; +#[cfg(feature = "profiler")] +pub mod profiler; + /// The default digest count when generating a MethodId. pub const DEFAULT_METHOD_ID_LIMIT: u32 = 12; /// A Result specialized for [Exception]. pub type Result = std::result::Result; +/// An event traced from the running VM. +#[non_exhaustive] +#[derive(Debug)] +pub enum TraceEvent { + /// An instruction has started at the given program counter + InstructionStart { + /// Cycle number since startup + cycle: u32, + /// Program counter of the instruction being executed + pc: u32, + }, + + /// A register has been set + RegisterSet { + /// Register ID (0-16) + reg: usize, + /// New value in the register + value: u32, + }, + + /// A memory location has been written + MemorySet { + /// Address of word that's been written + addr: u32, + /// Value of word that's been written + value: u32, + }, +} + /// Options available to modify the prover's behavior. pub struct ProverOpts<'a> { pub(crate) skip_seal: bool, pub(crate) sendrecv_callbacks: HashMap Vec + 'a + Sync>>, + + pub(crate) trace_callback: Option anyhow::Result<()> + 'a>>, } impl<'a> ProverOpts<'a> { @@ -64,6 +98,16 @@ impl<'a> ProverOpts<'a> { .insert(channel_id, Box::new(callback)); self } + + /// Add a callback handler for raw trace messages. + pub fn with_trace_callback( + mut self, + callback: impl FnMut(TraceEvent) -> anyhow::Result<()> + 'a, + ) -> Self { + assert!(!self.trace_callback.is_some(), "Duplicate trace callback"); + self.trace_callback = Some(Box::new(callback)); + self + } } impl<'a> Default for ProverOpts<'a> { @@ -71,6 +115,7 @@ impl<'a> Default for ProverOpts<'a> { ProverOpts { skip_seal: false, sendrecv_callbacks: HashMap::new(), + trace_callback: None, } } } @@ -82,13 +127,15 @@ mod test { use anyhow::Result; use risc0_zkp::core::sha::Digest; use risc0_zkvm_methods::{ - FAIL_ID, FAIL_PATH, IO_ID, IO_PATH, SENDRECV_ID, SENDRECV_PATH, SHA_ACCEL_ID, - SHA_ACCEL_PATH, SHA_ID, SHA_PATH, + multi_test::MultiTestSpec, FAIL_ID, FAIL_PATH, IO_ID, IO_PATH, MULTI_TEST_ID, + MULTI_TEST_PATH, SENDRECV_ID, SENDRECV_PATH, SHA_ID, SHA_PATH, }; use risc0_zkvm_platform::memory::{COMMIT, HEAP}; use test_log::test; use super::{MethodId, Prover, ProverOpts, Receipt}; + #[cfg(feature = "pure-prove")] + use crate::host::TraceEvent; use crate::serde::{from_slice, to_vec}; #[test] @@ -263,12 +310,15 @@ mod test { #[test] fn sha_accel() { let opts = ProverOpts::default().with_skip_seal(true); - let mut prover = - Prover::new_with_opts(&std::fs::read(SHA_ACCEL_PATH).unwrap(), SHA_ACCEL_ID, opts) - .unwrap(); + let mut prover = Prover::new_with_opts( + &std::fs::read(MULTI_TEST_PATH).unwrap(), + MULTI_TEST_ID, + opts, + ) + .unwrap(); prover.add_input_u32_slice(&[ - 0, // Test risc0_zkvm_guest::sha::Impl - 0, // Compute an empty digest + MultiTestSpec::ShaConforms as _, // Test risc0_zkvm_guest::sha::Impl + 0, // Compute an empty digest ]); prover.run().unwrap(); } @@ -277,12 +327,15 @@ mod test { #[cfg(feature = "pure-prove")] fn insecure_sha_accel() { let opts = ProverOpts::default().with_skip_seal(true); - let mut prover = - Prover::new_with_opts(&std::fs::read(SHA_ACCEL_PATH).unwrap(), SHA_ACCEL_ID, opts) - .unwrap(); + let mut prover = Prover::new_with_opts( + &std::fs::read(MULTI_TEST_PATH).unwrap(), + MULTI_TEST_ID, + opts, + ) + .unwrap(); prover.add_input_u32_slice(&[ - 1, // Test risc0_zkvm_guest::sha_insecure::Impl - 0, // Compute an empty digest + MultiTestSpec::ShaInsecureConforms as _, // Test risc0_zkvm_guest::sha_insecure::Impl + 0, // Compute an empty digest ]); prover.run().unwrap(); } @@ -291,16 +344,159 @@ mod test { #[cfg(feature = "pure-prove")] fn sha_cycle_count() { let opts = ProverOpts::default().with_skip_seal(true); - let mut prover = - Prover::new_with_opts(&std::fs::read(SHA_ACCEL_PATH).unwrap(), SHA_ACCEL_ID, opts) - .unwrap(); + let mut prover = Prover::new_with_opts( + &std::fs::read(MULTI_TEST_PATH).unwrap(), + MULTI_TEST_ID, + opts, + ) + .unwrap(); prover.add_input_u32_slice(&[ - 2, // Check insecure cycle count < expected from accel - 0, // Compute an empty digest + MultiTestSpec::ShaCycleCount as _, // Check insecure cycle count < expected from accel + 0, // Compute an empty digest ]); prover.run().unwrap(); } + #[test] + #[cfg(feature = "pure-prove")] + fn profiler() { + use crate::host::profiler::Frame; + let elf_contents = std::fs::read(MULTI_TEST_PATH).unwrap(); + let mut prof = + crate::host::profiler::Profiler::new("multi_test.elf", &elf_contents).unwrap(); + { + let opts = ProverOpts::default() + .with_skip_seal(true) + .with_trace_callback(prof.make_trace_callback()); + let mut prover = Prover::new_with_opts(&elf_contents, MULTI_TEST_ID, opts).unwrap(); + prover.add_input_u32_slice(&[ + MultiTestSpec::Profiler as _, // Generate known profiling trace + 0, + ]); + prover.run().unwrap(); + } + + prof.finalize(); + + // Gather up anything containing our profile_test functions. + // If the test doesn't pass, we don't want to display the + // whole profiling structure. + let occurences: Vec<_> = prof + .iter() + .filter(|(frames, _addr, _count)| { + frames.iter().any(|fr| fr.name.contains("profile_test")) + }) + .collect(); + + assert!(!occurences.is_empty(), "{:#?}", Vec::from_iter(prof.iter())); + + let elf_mem = crate::elf::Program::load_elf(elf_contents.as_slice(), u32::MAX) + .unwrap() + .image; + + assert!( + occurences.iter().any(|(fr, addr, _count)| { + match fr.as_slice() { + [fr1 @ Frame { + name: name1, + filename: fn1, + .. + }, fr2 @ Frame { + name: name2, + filename: fn2, + .. + }] => { + println!("Inspecting frames:\n{fr1:?}\n{fr2:?}\n"); + if name1 != "profile_test_func2" || name2 != "profile_test_func1" { + println!("Names did not match: {}, {}", name1, name2); + return false; + } + if !fn1.ends_with("multi_test.rs") || !fn2.ends_with("multi_test.rs") { + println!("Filenames did not match: {}, {}", fn1, fn2); + return false; + } + // Check to make sure we hit the "nop" instruction + match elf_mem.get(&(*addr as u32)) { + None => { + println!("Addr {addr} not present in elf"); + return false; + } + Some(0x00_00_00_13) => (), + Some(inst) => { + println!("Looking for 'nop'; got 0x{inst:08x}"); + return false; + } + } + + // All checks passed; this is the occurence we were looking for. + true + } + _ => { + println!("{:#?}", fr); + false + } + } + }), + "{:#?}", + occurences + ); + } + + #[test] + #[cfg(feature = "pure-prove")] + fn trace() { + let mut events: Vec = Vec::new(); + { + let opts = ProverOpts::default() + .with_skip_seal(true) + .with_trace_callback(|event| Ok(events.push(event))); + let mut prover = Prover::new_with_opts( + &std::fs::read(MULTI_TEST_PATH).unwrap(), + MULTI_TEST_ID, + opts, + ) + .unwrap(); + prover.add_input_u32_slice(&[ + MultiTestSpec::EventTrace as _, // Generate known trace events + 0, + ]); + prover.run().unwrap(); + } + let occurances = events + .windows(4) + .filter_map(|window| { + if let &[TraceEvent::InstructionStart { + // li x5, 1337 + cycle: cycle1, + pc: pc1, + }, TraceEvent::RegisterSet { + reg: 5, + value: 1337, + }, TraceEvent::InstructionStart { + // sw x5, 548(zero) + cycle: cycle2, + pc: pc2, + }, TraceEvent::MemorySet { + addr: 548, + value: 1337, + }] = window + { + assert_eq!(cycle1 + 3, cycle2, "li should take 3 cycles: {:#?}", window); + assert_eq!( + pc1 + 4, + pc2, + "program counter should advance one word: {:#?}", + window + ); + Some(()) + } else { + None + } + }) + .count(); + assert_eq!(occurances, 1, "trace events: {:#?}", &events); + } + #[test] #[cfg(feature = "pure-prove")] fn recursion() { diff --git a/risc0/zkvm/sdk/rust/src/host/profile.proto b/risc0/zkvm/sdk/rust/src/host/profile.proto new file mode 100644 index 0000000000..ee0391f5e8 --- /dev/null +++ b/risc0/zkvm/sdk/rust/src/host/profile.proto @@ -0,0 +1,212 @@ +// Copyright 2016 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Profile is a common stacktrace profile format. +// +// Measurements represented with this format should follow the +// following conventions: +// +// - Consumers should treat unset optional fields as if they had been +// set with their default value. +// +// - When possible, measurements should be stored in "unsampled" form +// that is most useful to humans. There should be enough +// information present to determine the original sampled values. +// +// - On-disk, the serialized proto must be gzip-compressed. +// +// - The profile is represented as a set of samples, where each sample +// references a sequence of locations, and where each location belongs +// to a mapping. +// - There is a N->1 relationship from sample.location_id entries to +// locations. For every sample.location_id entry there must be a +// unique Location with that id. +// - There is an optional N->1 relationship from locations to +// mappings. For every nonzero Location.mapping_id there must be a +// unique Mapping with that id. + +syntax = "proto3"; + +package perftools.profiles; + +option java_package = "com.google.perftools.profiles"; +option java_outer_classname = "ProfileProto"; + +message Profile { + // A description of the samples associated with each Sample.value. + // For a cpu profile this might be: + // [["cpu","nanoseconds"]] or [["wall","seconds"]] or [["syscall","count"]] + // For a heap profile, this might be: + // [["allocations","count"], ["space","bytes"]], + // If one of the values represents the number of events represented + // by the sample, by convention it should be at index 0 and use + // sample_type.unit == "count". + repeated ValueType sample_type = 1; + // The set of samples recorded in this profile. + repeated Sample sample = 2; + // Mapping from address ranges to the image/binary/library mapped + // into that address range. mapping[0] will be the main binary. + repeated Mapping mapping = 3; + // Useful program location + repeated Location location = 4; + // Functions referenced by locations + repeated Function function = 5; + // A common table for strings referenced by various messages. + // string_table[0] must always be "". + repeated string string_table = 6; + // frames with Function.function_name fully matching the following + // regexp will be dropped from the samples, along with their successors. + int64 drop_frames = 7; // Index into string table. + // frames with Function.function_name fully matching the following + // regexp will be kept, even if it matches drop_frames. + int64 keep_frames = 8; // Index into string table. + + // The following fields are informational, do not affect + // interpretation of results. + + // Time of collection (UTC) represented as nanoseconds past the epoch. + int64 time_nanos = 9; + // Duration of the profile, if a duration makes sense. + int64 duration_nanos = 10; + // The kind of events between sampled ocurrences. + // e.g [ "cpu","cycles" ] or [ "heap","bytes" ] + ValueType period_type = 11; + // The number of events between sampled occurrences. + int64 period = 12; + // Freeform text associated to the profile. + repeated int64 comment = 13; // Indices into string table. + // Index into the string table of the type of the preferred sample + // value. If unset, clients should default to the last sample value. + int64 default_sample_type = 14; +} + +// ValueType describes the semantics and measurement units of a value. +message ValueType { + int64 type = 1; // Index into string table. + int64 unit = 2; // Index into string table. +} + +// Each Sample records values encountered in some program +// context. The program context is typically a stack trace, perhaps +// augmented with auxiliary information like the thread-id, some +// indicator of a higher level request being handled etc. +message Sample { + // The ids recorded here correspond to a Profile.location.id. + // The leaf is at location_id[0]. + repeated uint64 location_id = 1; + // The type and unit of each value is defined by the corresponding + // entry in Profile.sample_type. All samples must have the same + // number of values, the same as the length of Profile.sample_type. + // When aggregating multiple samples into a single sample, the + // result has a list of values that is the element-wise sum of the + // lists of the originals. + repeated int64 value = 2; + // label includes additional context for this sample. It can include + // things like a thread id, allocation size, etc + repeated Label label = 3; +} + +message Label { + int64 key = 1; // Index into string table + + // At most one of the following must be present + int64 str = 2; // Index into string table + int64 num = 3; + + // Should only be present when num is present. + // Specifies the units of num. + // Use arbitrary string (for example, "requests") as a custom count unit. + // If no unit is specified, consumer may apply heuristic to deduce the unit. + // Consumers may also interpret units like "bytes" and "kilobytes" as memory + // units and units like "seconds" and "nanoseconds" as time units, + // and apply appropriate unit conversions to these. + int64 num_unit = 4; // Index into string table +} + +message Mapping { + // Unique nonzero id for the mapping. + uint64 id = 1; + // Address at which the binary (or DLL) is loaded into memory. + uint64 memory_start = 2; + // The limit of the address range occupied by this mapping. + uint64 memory_limit = 3; + // Offset in the binary that corresponds to the first mapped address. + uint64 file_offset = 4; + // The object this entry is loaded from. This can be a filename on + // disk for the main binary and shared libraries, or virtual + // abstractions like "[vdso]". + int64 filename = 5; // Index into string table + // A string that uniquely identifies a particular program version + // with high probability. E.g., for binaries generated by GNU tools, + // it could be the contents of the .note.gnu.build-id field. + int64 build_id = 6; // Index into string table + + // The following fields indicate the resolution of symbolic info. + bool has_functions = 7; + bool has_filenames = 8; + bool has_line_numbers = 9; + bool has_inline_frames = 10; +} + +// Describes function and line table debug information. +message Location { + // Unique nonzero id for the location. A profile could use + // instruction addresses or any integer sequence as ids. + uint64 id = 1; + // The id of the corresponding profile.Mapping for this location. + // It can be unset if the mapping is unknown or not applicable for + // this profile type. + uint64 mapping_id = 2; + // The instruction address for this location, if available. It + // should be within [Mapping.memory_start...Mapping.memory_limit] + // for the corresponding mapping. A non-leaf address may be in the + // middle of a call instruction. It is up to display tools to find + // the beginning of the instruction if necessary. + uint64 address = 3; + // Multiple line indicates this location has inlined functions, + // where the last entry represents the caller into which the + // preceding entries were inlined. + // + // E.g., if memcpy() is inlined into printf: + // line[0].function_name == "memcpy" + // line[1].function_name == "printf" + repeated Line line = 4; + // Provides an indication that multiple symbols map to this location's + // address, for example due to identical code folding by the linker. In that + // case the line information above represents one of the multiple + // symbols. This field must be recomputed when the symbolization state of the + // profile changes. + bool is_folded = 5; +} + +message Line { + // The id of the corresponding profile.Function for this line. + uint64 function_id = 1; + // Line number in source code. + int64 line = 2; +} + +message Function { + // Unique nonzero id for the function. + uint64 id = 1; + // Name of the function, in human-readable form if available. + int64 name = 2; // Index into string table + // Name of the function, as identified by the system. + // For instance, it can be a C++ mangled name. + int64 system_name = 3; // Index into string table + // Source file containing the function. + int64 filename = 4; // Index into string table + // Line number in source file. + int64 start_line = 5; +} diff --git a/risc0/zkvm/sdk/rust/src/host/profiler.rs b/risc0/zkvm/sdk/rust/src/host/profiler.rs new file mode 100644 index 0000000000..fde2e50fd8 --- /dev/null +++ b/risc0/zkvm/sdk/rust/src/host/profiler.rs @@ -0,0 +1,291 @@ +// Copyright 2022 Risc0, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Support for profiling the guest. +//! +//! This counts cycles spent at each location when executing the +//! guest. This is a preliminary experimental implementation. It +//! does not trace full stack traces, but only provides the top level +//! stack frame. (More than one stack frame may show up in the case +//! of inlined functions). + +// TODO: +// +// * Count the full stack instead of the top frame; the "gimli" +// crate's UnwindSection and Evaluation should help us do this +// +// * Demangle symbols + +use std::collections::HashMap; + +use addr2line::{ + object::{read::File, Object, ObjectSegment}, + Context, +}; +use anyhow::Result; +use gimli::{EndianRcSlice, RunTimeEndian}; +use prost::Message; + +use crate::host::TraceEvent; + +mod proto { + // Generated proto interface. + include!(concat!(env!("OUT_DIR"), "/perftools.profiles.rs")); +} + +/// Manages profiling state +pub struct Profiler { + // Current program counter + pc: u32, + + // Cycle count when the last instruction started + cycle: u32, + + // Counts per program counter + counts: HashMap, + + ctx: Context>, + + profile: ProfileBuilder, +} + +/// Represents a frame. Prefer to export the whole profiler proto using +/// profiler.as_protobuf(). +#[derive(Debug)] +pub struct Frame { + /// Function name + pub name: String, + /// Line number + pub lineno: i64, + /// Filename where this function is defined + pub filename: String, +} + +fn decode_frame(fr: addr2line::Frame>) -> Option { + Some(Frame { + name: fr.function.as_ref()?.raw_name().ok()?.to_string(), + lineno: fr.location.as_ref()?.line? as i64, + filename: fr.location.as_ref()?.file?.to_string(), + }) +} + +fn lookup_pc(pc: u32, ctx: &Context>) -> Vec { + use addr2line::fallible_iterator::FallibleIterator; + match ctx.find_frames(pc as u64) { + Ok(frames) => frames + .filter_map(|fr| Ok(decode_frame(fr))) + .collect::>() + .unwrap(), + Err(err) => { + eprintln!("Error finding frames! {:?}", err); + [].into() + } + } +} + +impl Profiler { + /// Return a new profile from the given RISCV ELF. + pub fn new(filename: &str, elf_data: &[u8]) -> Result { + let file = File::parse(elf_data)?; + let ctx = Context::new(&file)?; + let mut profiler = Profiler { + pc: u32::MAX, + cycle: 0, + counts: HashMap::new(), + ctx, + profile: ProfileBuilder::new(), + }; + + // Save the main binary name + let bin_name = profiler.profile.get_string(filename); + for segment in file.segments() { + if segment.address() == risc0_zkvm_platform::memory::PROG.start() as u64 { + profiler.profile.profile.mapping.push(proto::Mapping { + id: 1, + memory_start: segment.address(), + memory_limit: segment.address() + segment.size(), + file_offset: segment.file_range().0, + filename: bin_name, + ..Default::default() + }); + } + } + + Ok(profiler) + } + + /// Dereferences strings, etc. in the protobuf for testing purposes. + /// Returns a tuple of (frames, program counter, cycles) + pub fn iter(&self) -> impl Iterator, usize, usize)> + '_ { + self.profile.iter() + } + + /// Returns a callback to populate this profiler, suitable for + /// passing to ProverOpts::with_trace_callback. + pub fn make_trace_callback<'a>( + &'a mut self, + ) -> impl FnMut(TraceEvent) -> anyhow::Result<()> + 'a { + |event| { + match event { + TraceEvent::InstructionStart { cycle, pc } => { + // Count against the last program counter. + let cycles = cycle - self.cycle; + let orig_pc = self.pc; + *self.counts.entry(orig_pc).or_insert(0) += cycles as usize; + self.pc = pc; + self.cycle = cycle; + } + _ => (), + } + Ok(()) + } + } + + /// Count and save the profiling samples + pub fn finalize(&mut self) { + if !self.profile.profile.sample.is_empty() { + return; + } + + for (pc, count) in self.counts.iter() { + let frames = lookup_pc(*pc, &self.ctx); + let loc = proto::Location { + address: *pc as u64, + line: frames + .into_iter() + .map(|fr| proto::Line { + function_id: self.profile.get_function(&fr.name, &fr.filename), + line: fr.lineno, + }) + .collect(), + ..Default::default() + }; + let sample = proto::Sample { + location_id: vec![self.profile.get_location(loc)], + value: vec![*count as i64], + ..Default::default() + }; + self.profile.add_sample(sample); + } + } + + /// Returns the result of this profiling run as a protobuf. + pub fn as_protobuf(&self) -> &proto::Profile { + assert!( + !self.profile.profile.sample.is_empty(), + "Call finalize() first to generate the protobuf" + ); + &self.profile.profile + } + + /// Returns the result of this profiling run, encoded and ready for writing + /// to a file. + pub fn encode_to_vec(&mut self) -> Vec { + self.as_protobuf().encode_to_vec() + } +} + +struct ProfileBuilder { + strings: HashMap, + + functions: HashMap<(String, String), u64>, + + profile: proto::Profile, +} + +impl ProfileBuilder { + fn new() -> Self { + let mut builder = Self { + strings: HashMap::new(), + functions: HashMap::new(), + profile: Default::default(), + }; + + // First string must always be the empty string + assert_eq!(0, builder.get_string("")); + + // Set up defaults for us + let sample_type = proto::ValueType { + r#type: builder.get_string("cycles"), + unit: builder.get_string("count"), + ..Default::default() + }; + builder.profile.sample_type.push(sample_type); + + builder + } + + fn get_location(&mut self, mut loc: proto::Location) -> u64 { + let id = self.profile.location.len() as u64 + 1; + loc.id = id; + if !self.profile.mapping.is_empty() { + loc.mapping_id = 1; + } + self.profile.location.push(loc); + id + } + + fn get_function(&mut self, name: &str, filename: &str) -> u64 { + let key = (name.to_string(), filename.to_string()); + if let Some(&id) = self.functions.get(&key) { + return id; + } + + let id = self.profile.function.len() as u64 + 1; + let name = self.get_string(name); + let filename = self.get_string(filename); + self.profile.function.push(proto::Function { + id, + name, + filename, + ..Default::default() + }); + self.functions.insert(key, id); + id + } + + // Adds a string to the string table + fn get_string(&mut self, s: &str) -> i64 { + *self.strings.entry(s.to_string()).or_insert_with(|| { + let idx = self.profile.string_table.len() as i64; + self.profile.string_table.push(s.to_string()); + idx + }) + } + + fn add_sample(&mut self, sample: proto::Sample) { + self.profile.sample.push(sample) + } + + fn iter(&self) -> impl Iterator, usize, usize)> + '_ { + self.profile.sample.iter().map(|sample| { + let loc = &self.profile.location[sample.location_id[0] as usize - 1]; + ( + loc.line + .iter() + .map(|line| { + let func = &self.profile.function[line.function_id as usize - 1]; + Frame { + name: self.profile.string_table[func.name as usize].clone(), + lineno: line.line, + filename: self.profile.string_table[func.filename as usize].clone(), + } + }) + .collect(), + loc.address as usize, + sample.value[0] as usize, + ) + }) + } +} diff --git a/risc0/zkvm/sdk/rust/src/prove/exec.rs b/risc0/zkvm/sdk/rust/src/prove/exec.rs index 1b9ce58278..6adad14635 100644 --- a/risc0/zkvm/sdk/rust/src/prove/exec.rs +++ b/risc0/zkvm/sdk/rust/src/prove/exec.rs @@ -51,12 +51,14 @@ use risc0_zkvm_platform::{ }; use super::ffpu::ffpu_execute; -use crate::{elf::Program, CIRCUIT}; +use crate::{elf::Program, host::TraceEvent, CIRCUIT}; pub trait IoHandler { + fn is_trace_enabled(&self) -> bool; fn on_commit(&mut self, buf: &[u32]) -> Result<()>; fn on_fault(&mut self, msg: &str) -> Result<()>; fn on_txrx(&mut self, channel: u32, buf: &[u8]) -> Result>; + fn on_trace(&mut self, event: TraceEvent) -> Result<()>; } #[derive(Clone, PartialEq, Eq)] @@ -285,6 +287,7 @@ pub struct MachineContext<'a, H: IoHandler> { memory: MemoryState, io: &'a mut H, cur_host_to_guest_offset: usize, + trace_enabled: bool, } impl PartialOrd for MemoryEvent { @@ -367,6 +370,7 @@ impl<'a, H: IoHandler> MachineContext<'a, H> { pub fn new(io: &'a mut H) -> Self { MachineContext { memory: MemoryState::new(), + trace_enabled: io.is_trace_enabled(), io, cur_host_to_guest_offset: INPUT.start(), } @@ -383,7 +387,31 @@ impl<'a, H: IoHandler> MachineContext<'a, H> { (split_word(quot), split_word(rem)) } - fn log(&self, msg: &str, args: &[Fp]) { + fn extract_trace(&mut self, message: &str, args: &[Fp]) -> Result<()> { + match message { + msg if msg.starts_with("C%u: pc: %08x Decode") => { + let (cycle, pc) = (args[0], args[1]); + self.io.on_trace(TraceEvent::InstructionStart { + cycle: cycle.into(), + pc: pc.into(), + })? + } + "C%u: pc: %08x Final: 0x%04x%04x -> r%u, next: %08x" => { + let (val_high, val_low, reg) = (args[2], args[3], args[4]); + self.io.on_trace(TraceEvent::RegisterSet { + reg: u32::from(reg) as usize, + value: ((u32::from(val_high)) * (1 << 16) + (u32::from(val_low))), + })? + } + _ => (), + } + Ok(()) + } + + fn log(&mut self, msg: &str, args: &[Fp]) { + if self.trace_enabled { + self.extract_trace(msg, args).unwrap(); + } if log::max_level() < log::LevelFilter::Trace { // Don't bother to format it if we're not even logging. return; @@ -472,6 +500,12 @@ impl<'a, H: IoHandler> MachineContext<'a, H> { } }; self.on_write(cycle, addr * 4, data)?; + if self.trace_enabled { + self.io.on_trace(TraceEvent::MemorySet { + addr: addr * 4, + value: data, + })? + } Ok(()) } diff --git a/risc0/zkvm/sdk/rust/src/prove/mod.rs b/risc0/zkvm/sdk/rust/src/prove/mod.rs index 3243b90e49..56c51fe24d 100644 --- a/risc0/zkvm/sdk/rust/src/prove/mod.rs +++ b/risc0/zkvm/sdk/rust/src/prove/mod.rs @@ -30,7 +30,13 @@ use risc0_zkvm_platform::{ }; use self::cpu_eval::CpuEvalCheck; -use crate::{elf::Program, host::ProverOpts, method_id::MethodId, receipt::Receipt, CIRCUIT}; +use crate::{ + elf::Program, + host::{ProverOpts, TraceEvent}, + method_id::MethodId, + receipt::Receipt, + CIRCUIT, +}; pub struct Prover<'a> { elf: Program, @@ -144,6 +150,18 @@ impl<'a> exec::IoHandler for ProverImpl<'a> { } } + fn is_trace_enabled(&self) -> bool { + self.opts.trace_callback.is_some() + } + + fn on_trace(&mut self, event: TraceEvent) -> Result<()> { + if let Some(ref mut cb) = self.opts.trace_callback { + cb(event) + } else { + Ok(()) + } + } + fn on_commit(&mut self, buf: &[u32]) -> Result<()> { self.commit.extend_from_slice(buf); Ok(())